diff --git a/.github/actions/go-test-setup/action.yml b/.github/actions/go-test-setup/action.yml new file mode 100644 index 00000000..6b15ea06 --- /dev/null +++ b/.github/actions/go-test-setup/action.yml @@ -0,0 +1,25 @@ +name: Go Test Setup +description: Set up the environment for go test +runs: + using: "composite" + steps: + - name: Common setup + shell: bash + run: | + echo 'CGO_ENABLED=1' >> $GITHUB_ENV + - name: Windows setup + shell: bash + if: ${{ runner.os == 'Windows' }} + run: | + pacman -S --noconfirm mingw-w64-x86_64-toolchain mingw-w64-i686-toolchain + echo '/c/msys64/mingw64/bin' >> $GITHUB_PATH + echo 'PATH_386=/c/msys64/mingw32/bin:${{ env.PATH_386 }}' >> $GITHUB_ENV + - name: Linux setup + shell: bash + if: ${{ runner.os == 'Linux' }} + run: | + sudo apt-get install gcc-multilib + sudo dpkg --add-architecture i386 + sudo apt-get update + sudo apt-get install libssl-dev:i386 + echo 'CC_FOR_linux_386=i686-w64-mingw32-gcc' diff --git a/.github/workflows/automerge.yml b/.github/workflows/automerge.yml new file mode 100644 index 00000000..3833fc22 --- /dev/null +++ b/.github/workflows/automerge.yml @@ -0,0 +1,11 @@ +# File managed by web3-bot. DO NOT EDIT. +# See https://github.com/protocol/.github/ for details. + +name: Automerge +on: [ pull_request ] + +jobs: + automerge: + uses: protocol/.github/.github/workflows/automerge.yml@master + with: + job: 'automerge' diff --git a/.github/workflows/go-check.yml b/.github/workflows/go-check.yml new file mode 100644 index 00000000..251f7faa --- /dev/null +++ b/.github/workflows/go-check.yml @@ -0,0 +1,73 @@ +# File managed by web3-bot. DO NOT EDIT. +# See https://github.com/protocol/.github/ for details. + +on: [push, pull_request] +name: Go Checks + +jobs: + unit: + runs-on: ubuntu-latest + name: All + env: + RUNGOGENERATE: false + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + - uses: actions/setup-go@v3 + with: + go-version: "1.19.x" + - name: Run repo-specific setup + uses: ./.github/actions/go-check-setup + if: hashFiles('./.github/actions/go-check-setup') != '' + - name: Read config + if: hashFiles('./.github/workflows/go-check-config.json') != '' + run: | + if jq -re .gogenerate ./.github/workflows/go-check-config.json; then + echo "RUNGOGENERATE=true" >> $GITHUB_ENV + fi + - name: Install staticcheck + run: go install honnef.co/go/tools/cmd/staticcheck@376210a89477dedbe6fdc4484b233998650d7b3c # 2022.1.3 (v0.3.3) + - name: Check that go.mod is tidy + uses: protocol/multiple-go-modules@v1.2 + with: + run: | + go mod tidy + if [[ -n $(git ls-files --other --exclude-standard --directory -- go.sum) ]]; then + echo "go.sum was added by go mod tidy" + exit 1 + fi + git diff --exit-code -- go.sum go.mod + - name: gofmt + if: ${{ success() || failure() }} # run this step even if the previous one failed + run: | + out=$(gofmt -s -l .) + if [[ -n "$out" ]]; then + echo $out | awk '{print "::error file=" $0 ",line=0,col=0::File is not gofmt-ed."}' + exit 1 + fi + - name: go vet + if: ${{ success() || failure() }} # run this step even if the previous one failed + uses: protocol/multiple-go-modules@v1.2 + with: + run: go vet ./... + - name: staticcheck + if: ${{ success() || failure() }} # run this step even if the previous one failed + uses: protocol/multiple-go-modules@v1.2 + with: + run: | + set -o pipefail + staticcheck ./... | sed -e 's@\(.*\)\.go@./\1.go@g' + - name: go generate + uses: protocol/multiple-go-modules@v1.2 + if: (success() || failure()) && env.RUNGOGENERATE == 'true' + with: + run: | + git clean -fd # make sure there aren't untracked files / directories + go generate ./... + # check if go generate modified or added any files + if ! $(git add . && git diff-index HEAD --exit-code --quiet); then + echo "go generated caused changes to the repository:" + git status --short + exit 1 + fi diff --git a/.github/workflows/go-test-ubuntu-22.04.yml b/.github/workflows/go-test-ubuntu-22.04.yml new file mode 100644 index 00000000..cb086365 --- /dev/null +++ b/.github/workflows/go-test-ubuntu-22.04.yml @@ -0,0 +1,69 @@ +# See: +# https://github.com/libp2p/go-openssl/pull/25 +# https://github.com/protocol/.github/issues/349 +# for details. +on: [push, pull_request] +name: Go Test + +jobs: + unit: + strategy: + fail-fast: false + matrix: + os: [ "ubuntu-22.04" ] + go: [ "1.17.x", "1.18.x" ] + env: + COVERAGES: "" + runs-on: ${{ matrix.os }} + name: ${{ matrix.os }} (go ${{ matrix.go }}) + steps: + - uses: actions/checkout@v2 + with: + submodules: recursive + - uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + - name: Go information + run: | + go version + go env + - name: Use msys2 on windows + if: startsWith(matrix.os, 'windows') + shell: bash + # The executable for msys2 is also called bash.cmd + # https://github.com/actions/virtual-environments/blob/main/images/win/Windows2019-Readme.md#shells + # If we prepend its location to the PATH + # subsequent 'shell: bash' steps will use msys2 instead of gitbash + run: echo "C:/msys64/usr/bin" >> $GITHUB_PATH + - name: Run repo-specific setup + uses: ./.github/actions/go-test-setup + if: hashFiles('./.github/actions/go-test-setup') != '' + - name: Run tests + uses: protocol/multiple-go-modules@v1.2 + with: + # Use -coverpkg=./..., so that we include cross-package coverage. + # If package ./A imports ./B, and ./A's tests also cover ./B, + # this means ./B's coverage will be significantly higher than 0%. + run: go test -v -coverprofile=module-coverage.txt -coverpkg=./... ./... + - name: Run tests (32 bit) + if: startsWith(matrix.os, 'macos') == false # can't run 32 bit tests on OSX. + uses: protocol/multiple-go-modules@v1.2 + env: + GOARCH: 386 + with: + run: | + export "PATH=${{ env.PATH_386 }}:$PATH" + go test -v ./... + - name: Run tests with race detector + if: startsWith(matrix.os, 'ubuntu') # speed things up. Windows and OSX VMs are slow + uses: protocol/multiple-go-modules@v1.2 + with: + run: go test -v -race ./... + - name: Collect coverage files + shell: bash + run: echo "COVERAGES=$(find . -type f -name 'module-coverage.txt' | tr -s '\n' ',' | sed 's/,$//')" >> $GITHUB_ENV + - name: Upload coverage to Codecov + uses: codecov/codecov-action@f32b3a3741e1053eb607407145bc9619351dc93b # v2.1.0 + with: + files: '${{ env.COVERAGES }}' + env_vars: OS=${{ matrix.os }}, GO=${{ matrix.go }} diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml new file mode 100644 index 00000000..8a1697b2 --- /dev/null +++ b/.github/workflows/go-test.yml @@ -0,0 +1,68 @@ +# File managed by web3-bot. DO NOT EDIT. +# See https://github.com/protocol/.github/ for details. + +on: [push, pull_request] +name: Go Test + +jobs: + unit: + strategy: + fail-fast: false + matrix: + os: [ "ubuntu", "windows", "macos" ] + go: [ "1.18.x", "1.19.x" ] + env: + COVERAGES: "" + runs-on: ${{ format('{0}-latest', matrix.os) }} + name: ${{ matrix.os }} (go ${{ matrix.go }}) + steps: + - uses: actions/checkout@v3 + with: + submodules: recursive + - uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - name: Go information + run: | + go version + go env + - name: Use msys2 on windows + if: ${{ matrix.os == 'windows' }} + shell: bash + # The executable for msys2 is also called bash.cmd + # https://github.com/actions/virtual-environments/blob/main/images/win/Windows2019-Readme.md#shells + # If we prepend its location to the PATH + # subsequent 'shell: bash' steps will use msys2 instead of gitbash + run: echo "C:/msys64/usr/bin" >> $GITHUB_PATH + - name: Run repo-specific setup + uses: ./.github/actions/go-test-setup + if: hashFiles('./.github/actions/go-test-setup') != '' + - name: Run tests + uses: protocol/multiple-go-modules@v1.2 + with: + # Use -coverpkg=./..., so that we include cross-package coverage. + # If package ./A imports ./B, and ./A's tests also cover ./B, + # this means ./B's coverage will be significantly higher than 0%. + run: go test -v -shuffle=on -coverprofile=module-coverage.txt -coverpkg=./... ./... + - name: Run tests (32 bit) + if: ${{ matrix.os != 'macos' }} # can't run 32 bit tests on OSX. + uses: protocol/multiple-go-modules@v1.2 + env: + GOARCH: 386 + with: + run: | + export "PATH=${{ env.PATH_386 }}:$PATH" + go test -v -shuffle=on ./... + - name: Run tests with race detector + if: ${{ matrix.os == 'ubuntu' }} # speed things up. Windows and OSX VMs are slow + uses: protocol/multiple-go-modules@v1.2 + with: + run: go test -v -race ./... + - name: Collect coverage files + shell: bash + run: echo "COVERAGES=$(find . -type f -name 'module-coverage.txt' | tr -s '\n' ',' | sed 's/,$//')" >> $GITHUB_ENV + - name: Upload coverage to Codecov + uses: codecov/codecov-action@81cd2dc8148241f03f5839d295e000b8f761e378 # v3.1.0 + with: + files: '${{ env.COVERAGES }}' + env_vars: OS=${{ matrix.os }}, GO=${{ matrix.go }} diff --git a/.github/workflows/release-check.yml b/.github/workflows/release-check.yml new file mode 100644 index 00000000..fde81c1f --- /dev/null +++ b/.github/workflows/release-check.yml @@ -0,0 +1,11 @@ +# File managed by web3-bot. DO NOT EDIT. +# See https://github.com/protocol/.github/ for details. + +name: Release Checker +on: + pull_request: + paths: [ 'version.json' ] + +jobs: + release-check: + uses: protocol/.github/.github/workflows/release-check.yml@master diff --git a/.github/workflows/releaser.yml b/.github/workflows/releaser.yml new file mode 100644 index 00000000..cdccbf87 --- /dev/null +++ b/.github/workflows/releaser.yml @@ -0,0 +1,11 @@ +# File managed by web3-bot. DO NOT EDIT. +# See https://github.com/protocol/.github/ for details. + +name: Releaser +on: + push: + paths: [ 'version.json' ] + +jobs: + releaser: + uses: protocol/.github/.github/workflows/releaser.yml@master diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000..6f6d895d --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,26 @@ +name: Close and mark stale issue + +on: + schedule: + - cron: '0 0 * * *' + +jobs: + stale: + + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + + steps: + - uses: actions/stale@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: 'Oops, seems like we needed more information for this issue, please comment with more details or this issue will be closed in 7 days.' + close-issue-message: 'This issue was closed because it is missing author input.' + stale-issue-label: 'kind/stale' + any-of-labels: 'need/author-input' + exempt-issue-labels: 'need/triage,need/community-input,need/maintainer-input,need/maintainers-input,need/analysis,status/blocked,status/in-progress,status/ready,status/deferred,status/inactive' + days-before-issue-stale: 6 + days-before-issue-close: 7 + enable-statistics: true diff --git a/.github/workflows/tagpush.yml b/.github/workflows/tagpush.yml new file mode 100644 index 00000000..d8499618 --- /dev/null +++ b/.github/workflows/tagpush.yml @@ -0,0 +1,12 @@ +# File managed by web3-bot. DO NOT EDIT. +# See https://github.com/protocol/.github/ for details. + +name: Tag Push Checker +on: + push: + tags: + - v* + +jobs: + releaser: + uses: protocol/.github/.github/workflows/tagpush.yml@master diff --git a/README.md b/README.md index 854df05a..62ac7dcd 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,16 @@ # OpenSSL bindings for Go -Please see http://godoc.org/github.com/spacemonkeygo/openssl for more info +Forked from https://github.com/spacemonkeygo/openssl (unmaintained) to add: + +1. FreeBSD support. +2. Key equality checking. +3. A function to get the size of signatures produced by a key. + +--- + +Please see http://godoc.org/github.com/libp2p/go-openssl for more info + +--- ### License diff --git a/alloc.go b/alloc.go new file mode 100644 index 00000000..25d064a2 --- /dev/null +++ b/alloc.go @@ -0,0 +1,19 @@ +package openssl + +// #include "shim.h" +import "C" + +import ( + "unsafe" + + "github.com/mattn/go-pointer" +) + +//export go_ssl_crypto_ex_free +func go_ssl_crypto_ex_free( + parent *C.void, ptr unsafe.Pointer, + cryptoData *C.CRYPTO_EX_DATA, idx C.int, + argl C.long, argp *C.void, +) { + pointer.Unref(ptr) +} diff --git a/bio.go b/bio.go index 9fe32aa8..caf2b37a 100644 --- a/bio.go +++ b/bio.go @@ -112,14 +112,14 @@ func writeBioPending(b *C.BIO) C.long { return C.long(len(ptr.buf)) } -func (b *writeBio) WriteTo(w io.Writer) (rv int64, err error) { - b.op_mtx.Lock() - defer b.op_mtx.Unlock() +func (wb *writeBio) WriteTo(w io.Writer) (rv int64, err error) { + wb.op_mtx.Lock() + defer wb.op_mtx.Unlock() // write whatever data we currently have - b.data_mtx.Lock() - data := b.buf - b.data_mtx.Unlock() + wb.data_mtx.Lock() + data := wb.buf + wb.data_mtx.Unlock() if len(data) == 0 { return 0, nil @@ -127,26 +127,26 @@ func (b *writeBio) WriteTo(w io.Writer) (rv int64, err error) { n, err := w.Write(data) // subtract however much data we wrote from the buffer - b.data_mtx.Lock() - b.buf = b.buf[:copy(b.buf, b.buf[n:])] - if b.release_buffers && len(b.buf) == 0 { - b.buf = nil + wb.data_mtx.Lock() + wb.buf = wb.buf[:copy(wb.buf, wb.buf[n:])] + if wb.release_buffers && len(wb.buf) == 0 { + wb.buf = nil } - b.data_mtx.Unlock() + wb.data_mtx.Unlock() return int64(n), err } -func (self *writeBio) Disconnect(b *C.BIO) { - if loadWritePtr(b) == self { +func (wb *writeBio) Disconnect(b *C.BIO) { + if loadWritePtr(b) == wb { writeBioMapping.Del(token(C.X_BIO_get_data(b))) C.X_BIO_set_data(b, nil) } } -func (b *writeBio) MakeCBIO() *C.BIO { +func (wb *writeBio) MakeCBIO() *C.BIO { rv := C.X_BIO_new_write_bio() - token := writeBioMapping.Add(unsafe.Pointer(b)) + token := writeBioMapping.Add(unsafe.Pointer(wb)) C.X_BIO_set_data(rv, unsafe.Pointer(token)) return rv } @@ -228,53 +228,53 @@ func readBioPending(b *C.BIO) C.long { return C.long(len(ptr.buf)) } -func (b *readBio) ReadFromOnce(r io.Reader) (n int, err error) { - b.op_mtx.Lock() - defer b.op_mtx.Unlock() +func (rb *readBio) ReadFromOnce(r io.Reader) (n int, err error) { + rb.op_mtx.Lock() + defer rb.op_mtx.Unlock() // make sure we have a destination that fits at least one SSL record - b.data_mtx.Lock() - if cap(b.buf) < len(b.buf)+SSLRecordSize { - new_buf := make([]byte, len(b.buf), len(b.buf)+SSLRecordSize) - copy(new_buf, b.buf) - b.buf = new_buf + rb.data_mtx.Lock() + if cap(rb.buf) < len(rb.buf)+SSLRecordSize { + new_buf := make([]byte, len(rb.buf), len(rb.buf)+SSLRecordSize) + copy(new_buf, rb.buf) + rb.buf = new_buf } - dst := b.buf[len(b.buf):cap(b.buf)] - dst_slice := b.buf - b.data_mtx.Unlock() + dst := rb.buf[len(rb.buf):cap(rb.buf)] + dst_slice := rb.buf + rb.data_mtx.Unlock() n, err = r.Read(dst) - b.data_mtx.Lock() - defer b.data_mtx.Unlock() + rb.data_mtx.Lock() + defer rb.data_mtx.Unlock() if n > 0 { - if len(dst_slice) != len(b.buf) { + if len(dst_slice) != len(rb.buf) { // someone shrunk the buffer, so we read in too far ahead and we // need to slide backwards - copy(b.buf[len(b.buf):len(b.buf)+n], dst) + copy(rb.buf[len(rb.buf):len(rb.buf)+n], dst) } - b.buf = b.buf[:len(b.buf)+n] + rb.buf = rb.buf[:len(rb.buf)+n] } return n, err } -func (b *readBio) MakeCBIO() *C.BIO { +func (rb *readBio) MakeCBIO() *C.BIO { rv := C.X_BIO_new_read_bio() - token := readBioMapping.Add(unsafe.Pointer(b)) + token := readBioMapping.Add(unsafe.Pointer(rb)) C.X_BIO_set_data(rv, unsafe.Pointer(token)) return rv } -func (self *readBio) Disconnect(b *C.BIO) { - if loadReadPtr(b) == self { +func (rb *readBio) Disconnect(b *C.BIO) { + if loadReadPtr(b) == rb { readBioMapping.Del(token(C.X_BIO_get_data(b))) C.X_BIO_set_data(b, nil) } } -func (b *readBio) MarkEOF() { - b.data_mtx.Lock() - defer b.data_mtx.Unlock() - b.eof = true +func (rb *readBio) MarkEOF() { + rb.data_mtx.Lock() + defer rb.data_mtx.Unlock() + rb.eof = true } type anyBio C.BIO diff --git a/build.go b/build.go index 5fccc021..f85aec98 100644 --- a/build.go +++ b/build.go @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build !openssl_static +//go:build !openssl_static package openssl -// #cgo linux windows pkg-config: libssl libcrypto -// #cgo linux CFLAGS: -Wno-deprecated-declarations -// #cgo darwin CFLAGS: -I/usr/local/opt/openssl@1.1/include -I/usr/local/opt/openssl/include -Wno-deprecated-declarations -// #cgo darwin LDFLAGS: -L/usr/local/opt/openssl@1.1/lib -L/usr/local/opt/openssl/lib -lssl -lcrypto +// #cgo linux windows freebsd openbsd solaris pkg-config: libssl libcrypto +// #cgo linux freebsd openbsd solaris CFLAGS: -Wno-deprecated-declarations +// #cgo darwin 386 CFLAGS: -I/usr/local/opt/openssl@1.1/include -I/usr/local/opt/openssl/include -Wno-deprecated-declarations +// #cgo darwin 386 LDFLAGS: -L/usr/local/opt/openssl@1.1/lib -L/usr/local/opt/openssl/lib -lssl -lcrypto +// #cgo darwin arm64 CFLAGS: -I/opt/homebrew/opt/openssl@1.1/include -Wno-deprecated-declarations +// #cgo darwin arm64 LDFLAGS: -L/opt/homebrew/opt/openssl@1.1/lib -lssl -lcrypto // #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN import "C" diff --git a/build_static.go b/build_static.go index c84427bc..f2c87cc5 100644 --- a/build_static.go +++ b/build_static.go @@ -12,13 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build openssl_static +//go:build openssl_static package openssl -// #cgo linux windows pkg-config: --static libssl libcrypto -// #cgo linux CFLAGS: -Wno-deprecated-declarations -// #cgo darwin CFLAGS: -I/usr/local/opt/openssl@1.1/include -I/usr/local/opt/openssl/include -Wno-deprecated-declarations -// #cgo darwin LDFLAGS: -L/usr/local/opt/openssl@1.1/lib -L/usr/local/opt/openssl/lib -lssl -lcrypto +// #cgo linux windows freebsd openbsd solaris pkg-config: --static libssl libcrypto +// #cgo linux freebsd openbsd solaris CFLAGS: -Wno-deprecated-declarations +// #cgo darwin 386 CFLAGS: -I/usr/local/opt/openssl@1.1/include -I/usr/local/opt/openssl/include -Wno-deprecated-declarations +// #cgo darwin 386 LDFLAGS: -L/usr/local/opt/openssl@1.1/lib -L/usr/local/opt/openssl/lib -lssl -lcrypto +// #cgo darwin arm64 CFLAGS: -I/opt/homebrew/opt/openssl@1.1/include -Wno-deprecated-declarations +// #cgo darwin arm64 LDFLAGS: -L/opt/homebrew/opt/openssl@1.1/lib -lssl -lcrypto // #cgo windows CFLAGS: -DWIN32_LEAN_AND_MEAN import "C" diff --git a/cert.go b/cert.go index e841e22c..fb482c0b 100644 --- a/cert.go +++ b/cert.go @@ -19,7 +19,7 @@ import "C" import ( "errors" - "io/ioutil" + "io" "math/big" "runtime" "time" @@ -267,8 +267,8 @@ func (c *Certificate) Sign(privKey PrivateKey, digest EVP_MD) error { case EVP_SHA384: case EVP_SHA512: default: - return errors.New("Unsupported digest" + - "You're probably looking for 'EVP_SHA256' or 'EVP_SHA512'.") + return errors.New("unsupported digest; " + + "you're probably looking for 'EVP_SHA256' or 'EVP_SHA512'") } return c.insecureSign(privKey, digest) } @@ -331,6 +331,16 @@ func (c *Certificate) AddExtension(nid NID, value string) error { return nil } +// AddCustomExtension add custom extenstion to the certificate. +func (c *Certificate) AddCustomExtension(nid NID, value []byte) error { + val := (*C.char)(C.CBytes(value)) + defer C.free(unsafe.Pointer(val)) + if int(C.add_custom_ext(c.x, C.int(nid), val, C.int(len(value)))) == 0 { + return errors.New("unable to add extension") + } + return nil +} + // Wraps AddExtension using a map of NID to text extension. // Will return without finishing if it encounters an error. func (c *Certificate) AddExtensions(extensions map[NID]string) error { @@ -373,7 +383,7 @@ func (c *Certificate) MarshalPEM() (pem_block []byte, err error) { if int(C.PEM_write_bio_X509(bio, c.x)) != 1 { return nil, errors.New("failed dumping certificate") } - return ioutil.ReadAll(asAnyBio(bio)) + return io.ReadAll(asAnyBio(bio)) } // PublicKey returns the public key embedded in the X509 certificate. @@ -413,3 +423,10 @@ func (c *Certificate) SetVersion(version X509_Version) error { } return nil } + +// GetExtensionValue returns the value of the given NID's extension. +func (c *Certificate) GetExtensionValue(nid NID) []byte { + dataLength := C.int(0) + val := C.get_extention(c.x, C.int(nid), &dataLength) + return C.GoBytes(unsafe.Pointer(val), dataLength) +} diff --git a/ciphers.go b/ciphers.go index 509bf641..a3a597c4 100644 --- a/ciphers.go +++ b/ciphers.go @@ -125,7 +125,7 @@ func (ctx *cipherCtx) applyKeyAndIV(key, iv []byte) error { } else { res = C.EVP_DecryptInit_ex(ctx.ctx, nil, nil, kptr, iptr) } - if 1 != res { + if res != 1 { return errors.New("failed to apply key/IV") } } @@ -243,7 +243,7 @@ func newEncryptionCipherCtx(c *Cipher, e *Engine, key, iv []byte) ( if e != nil { eptr = e.e } - if 1 != C.EVP_EncryptInit_ex(ctx.ctx, c.ptr, eptr, nil, nil) { + if C.EVP_EncryptInit_ex(ctx.ctx, c.ptr, eptr, nil, nil) != 1 { return nil, errors.New("failed to initialize cipher context") } err = ctx.applyKeyAndIV(key, iv) @@ -266,7 +266,7 @@ func newDecryptionCipherCtx(c *Cipher, e *Engine, key, iv []byte) ( if e != nil { eptr = e.e } - if 1 != C.EVP_DecryptInit_ex(ctx.ctx, c.ptr, eptr, nil, nil) { + if C.EVP_DecryptInit_ex(ctx.ctx, c.ptr, eptr, nil, nil) != 1 { return nil, errors.New("failed to initialize cipher context") } err = ctx.applyKeyAndIV(key, iv) @@ -317,7 +317,7 @@ func (ctx *decryptionCipherCtx) DecryptUpdate(input []byte) ([]byte, error) { func (ctx *encryptionCipherCtx) EncryptFinal() ([]byte, error) { outbuf := make([]byte, ctx.BlockSize()) var outlen C.int - if 1 != C.EVP_EncryptFinal_ex(ctx.ctx, (*C.uchar)(&outbuf[0]), &outlen) { + if C.EVP_EncryptFinal_ex(ctx.ctx, (*C.uchar)(&outbuf[0]), &outlen) != 1 { return nil, errors.New("encryption failed") } return outbuf[:outlen], nil @@ -326,7 +326,7 @@ func (ctx *encryptionCipherCtx) EncryptFinal() ([]byte, error) { func (ctx *decryptionCipherCtx) DecryptFinal() ([]byte, error) { outbuf := make([]byte, ctx.BlockSize()) var outlen C.int - if 1 != C.EVP_DecryptFinal_ex(ctx.ctx, (*C.uchar)(&outbuf[0]), &outlen) { + if C.EVP_DecryptFinal_ex(ctx.ctx, (*C.uchar)(&outbuf[0]), &outlen) != 1 { // this may mean the tag failed to verify- all previous plaintext // returned must be considered faked and invalid return nil, errors.New("decryption failed") diff --git a/ciphers_gcm.go b/ciphers_gcm.go index 7b08e0fd..06ba0fed 100644 --- a/ciphers_gcm.go +++ b/ciphers_gcm.go @@ -86,8 +86,8 @@ func NewGCMEncryptionCipherCtx(blocksize int, e *Engine, key, iv []byte) ( return nil, fmt.Errorf("could not set IV len to %d: %s", len(iv), err) } - if 1 != C.EVP_EncryptInit_ex(ctx.ctx, nil, nil, nil, - (*C.uchar)(&iv[0])) { + if C.EVP_EncryptInit_ex(ctx.ctx, nil, nil, nil, + (*C.uchar)(&iv[0])) != 1 { return nil, errors.New("failed to apply IV") } } @@ -110,8 +110,8 @@ func NewGCMDecryptionCipherCtx(blocksize int, e *Engine, key, iv []byte) ( return nil, fmt.Errorf("could not set IV len to %d: %s", len(iv), err) } - if 1 != C.EVP_DecryptInit_ex(ctx.ctx, nil, nil, nil, - (*C.uchar)(&iv[0])) { + if C.EVP_DecryptInit_ex(ctx.ctx, nil, nil, nil, + (*C.uchar)(&iv[0])) != 1 { return nil, errors.New("failed to apply IV") } } @@ -123,8 +123,8 @@ func (ctx *authEncryptionCipherCtx) ExtraData(aad []byte) error { return nil } var outlen C.int - if 1 != C.EVP_EncryptUpdate(ctx.ctx, nil, &outlen, (*C.uchar)(&aad[0]), - C.int(len(aad))) { + if C.EVP_EncryptUpdate(ctx.ctx, nil, &outlen, (*C.uchar)(&aad[0]), + C.int(len(aad))) != 1 { return errors.New("failed to add additional authenticated data") } return nil @@ -135,8 +135,8 @@ func (ctx *authDecryptionCipherCtx) ExtraData(aad []byte) error { return nil } var outlen C.int - if 1 != C.EVP_DecryptUpdate(ctx.ctx, nil, &outlen, (*C.uchar)(&aad[0]), - C.int(len(aad))) { + if C.EVP_DecryptUpdate(ctx.ctx, nil, &outlen, (*C.uchar)(&aad[0]), + C.int(len(aad))) != 1 { return errors.New("failed to add additional authenticated data") } return nil diff --git a/ciphers_test.go b/ciphers_test.go index fe991ab4..0f1e3404 100644 --- a/ciphers_test.go +++ b/ciphers_test.go @@ -185,17 +185,16 @@ func TestBadTag(t *testing.T) { } // flip the last bit tag[len(tag)-1] ^= 1 - plaintext_out, err := doDecryption(key, iv, nil, ciphertext, tag, 128, 129) - if err == nil { + if _, err := doDecryption(key, iv, nil, ciphertext, tag, 128, 129); err == nil { t.Fatal("Expected error for bad tag, but got none") } // flip it back, try again just to make sure tag[len(tag)-1] ^= 1 - plaintext_out, err = doDecryption(key, iv, nil, ciphertext, tag, 128, 129) + plaintextOut, err := doDecryption(key, iv, nil, ciphertext, tag, 128, 129) if err != nil { t.Fatal("Decryption failure:", err) } - checkEqual(t, plaintext_out, plaintext) + checkEqual(t, plaintextOut, plaintext) } func TestBadCiphertext(t *testing.T) { @@ -211,17 +210,16 @@ func TestBadCiphertext(t *testing.T) { } // flip the last bit ciphertext[len(ciphertext)-1] ^= 1 - plaintext_out, err := doDecryption(key, iv, aad, ciphertext, tag, 192, 192) - if err == nil { + if _, err := doDecryption(key, iv, aad, ciphertext, tag, 192, 192); err == nil { t.Fatal("Expected error for bad ciphertext, but got none") } // flip it back, try again just to make sure ciphertext[len(ciphertext)-1] ^= 1 - plaintext_out, err = doDecryption(key, iv, aad, ciphertext, tag, 192, 192) + plaintextOut, err := doDecryption(key, iv, aad, ciphertext, tag, 192, 192) if err != nil { t.Fatal("Decryption failure:", err) } - checkEqual(t, plaintext_out, plaintext) + checkEqual(t, plaintextOut, plaintext) } func TestBadAAD(t *testing.T) { @@ -237,17 +235,16 @@ func TestBadAAD(t *testing.T) { } // flip the last bit aad[len(aad)-1] ^= 1 - plaintext_out, err := doDecryption(key, iv, aad, ciphertext, tag, 256, 256) - if err == nil { + if _, err := doDecryption(key, iv, aad, ciphertext, tag, 256, 256); err == nil { t.Fatal("Expected error for bad AAD, but got none") } // flip it back, try again just to make sure aad[len(aad)-1] ^= 1 - plaintext_out, err = doDecryption(key, iv, aad, ciphertext, tag, 256, 256) + plaintextOut, err := doDecryption(key, iv, aad, ciphertext, tag, 256, 256) if err != nil { t.Fatal("Decryption failure:", err) } - checkEqual(t, plaintext_out, plaintext) + checkEqual(t, plaintextOut, plaintext) } func TestNonAuthenticatedEncryption(t *testing.T) { diff --git a/conn.go b/conn.go index 964551cd..fc9421ff 100644 --- a/conn.go +++ b/conn.go @@ -27,14 +27,15 @@ import ( "time" "unsafe" - "github.com/spacemonkeygo/openssl/utils" + "github.com/libp2p/go-openssl/utils" + "github.com/mattn/go-pointer" ) var ( - zeroReturn = errors.New("zero return") - wantRead = errors.New("want read") - wantWrite = errors.New("want write") - tryAgain = errors.New("try again") + errZeroReturn = errors.New("zero return") + errWantRead = errors.New("want read") + errWantWrite = errors.New("want write") + errTryAgain = errors.New("try again") ) type Conn struct { @@ -137,7 +138,7 @@ func newConn(conn net.Conn, ctx *Ctx) (*Conn, error) { C.SSL_set_bio(ssl, into_ssl_cbio, from_ssl_cbio) s := &SSL{ssl: ssl} - C.SSL_set_ex_data(s.ssl, get_ssl_idx(), unsafe.Pointer(s)) + C.SSL_set_ex_data(s.ssl, get_ssl_idx(), pointer.Save(s)) c := &Conn{ SSL: s, @@ -192,7 +193,7 @@ func (c *Conn) GetCtx() *Ctx { return c.ctx } func (c *Conn) CurrentCipher() (string, error) { p := C.X_SSL_get_cipher_name(c.ssl) if p == nil { - return "", errors.New("Session not established") + return "", errors.New("session not established") } return C.GoString(p), nil @@ -247,7 +248,7 @@ func (c *Conn) getErrorHandler(rv C.int, errno error) func() error { if err != nil { return err } - return tryAgain + return errTryAgain } case C.SSL_ERROR_WANT_WRITE: return func() error { @@ -255,7 +256,7 @@ func (c *Conn) getErrorHandler(rv C.int, errno error) func() error { if err != nil { return err } - return tryAgain + return errTryAgain } case C.SSL_ERROR_SYSCALL: var err error @@ -303,8 +304,8 @@ func (c *Conn) handshake() func() error { // Handshake performs an SSL handshake. If a handshake is not manually // triggered, it will run before the first I/O on the encrypted stream. func (c *Conn) Handshake() error { - err := tryAgain - for err == tryAgain { + err := errTryAgain + for err == errTryAgain { err = c.handleError(c.handshake()) } go c.flushOutputBuffer() @@ -404,15 +405,15 @@ func (c *Conn) shutdown() func() error { } func (c *Conn) shutdownLoop() error { - err := tryAgain + err := errTryAgain shutdown_tries := 0 - for err == tryAgain { + for err == errTryAgain { shutdown_tries = shutdown_tries + 1 err = c.handleError(c.shutdown()) if err == nil { return c.flushOutputBuffer() } - if err == tryAgain && shutdown_tries >= 2 { + if err == errTryAgain && shutdown_tries >= 2 { return errors.New("shutdown requested a third time?") } } @@ -463,8 +464,8 @@ func (c *Conn) Read(b []byte) (n int, err error) { if len(b) == 0 { return 0, nil } - err = tryAgain - for err == tryAgain { + err = errTryAgain + for err == errTryAgain { n, errcb := c.read(b) err = c.handleError(errcb) if err == nil { @@ -504,8 +505,8 @@ func (c *Conn) Write(b []byte) (written int, err error) { if len(b) == 0 { return 0, nil } - err = tryAgain - for err == tryAgain { + err = errTryAgain + for err == errTryAgain { n, errcb := c.write(b) err = c.handleError(errcb) if err == nil { diff --git a/ctx.go b/ctx.go index 33befc40..7b624d90 100644 --- a/ctx.go +++ b/ctx.go @@ -20,13 +20,13 @@ import "C" import ( "errors" "fmt" - "io/ioutil" "os" "runtime" "sync" "time" "unsafe" + "github.com/mattn/go-pointer" "github.com/spacemonkeygo/spacelog" ) @@ -61,7 +61,7 @@ func newCtx(method *C.SSL_METHOD) (*Ctx, error) { return nil, errorFromErrorQueue() } c := &Ctx{ctx: ctx} - C.SSL_CTX_set_ex_data(ctx, get_ssl_ctx_idx(), unsafe.Pointer(c)) + C.SSL_CTX_set_ex_data(ctx, get_ssl_ctx_idx(), pointer.Save(c)) runtime.SetFinalizer(c, func(c *Ctx) { C.SSL_CTX_free(c.ctx) }) @@ -120,14 +120,14 @@ func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) { return nil, err } - cert_bytes, err := ioutil.ReadFile(cert_file) + cert_bytes, err := os.ReadFile(cert_file) if err != nil { return nil, err } certs := SplitPEM(cert_bytes) if len(certs) == 0 { - return nil, fmt.Errorf("No PEM certificate found in '%s'", cert_file) + return nil, fmt.Errorf("no PEM certificate found in '%s'", cert_file) } first, certs := certs[0], certs[1:] cert, err := LoadCertificateFromPEM(first) @@ -151,7 +151,7 @@ func NewCtxFromFiles(cert_file string, key_file string) (*Ctx, error) { } } - key_bytes, err := ioutil.ReadFile(key_file) + key_bytes, err := os.ReadFile(key_file) if err != nil { return nil, err } @@ -190,7 +190,7 @@ func (c *Ctx) SetEllipticCurve(curve EllipticCurve) error { k := C.EC_KEY_new_by_curve_name(C.int(curve)) if k == nil { - return errors.New("Unknown curve") + return errors.New("unknown curve") } defer C.EC_KEY_free(k) @@ -302,12 +302,12 @@ type CertificateStoreCtx struct { ssl_ctx *Ctx } -func (self *CertificateStoreCtx) VerifyResult() VerifyResult { - return VerifyResult(C.X509_STORE_CTX_get_error(self.ctx)) +func (csc *CertificateStoreCtx) VerifyResult() VerifyResult { + return VerifyResult(C.X509_STORE_CTX_get_error(csc.ctx)) } -func (self *CertificateStoreCtx) Err() error { - code := C.X509_STORE_CTX_get_error(self.ctx) +func (csc *CertificateStoreCtx) Err() error { + code := C.X509_STORE_CTX_get_error(csc.ctx) if code == C.X509_V_OK { return nil } @@ -315,19 +315,19 @@ func (self *CertificateStoreCtx) Err() error { C.GoString(C.X509_verify_cert_error_string(C.long(code)))) } -func (self *CertificateStoreCtx) Depth() int { - return int(C.X509_STORE_CTX_get_error_depth(self.ctx)) +func (csc *CertificateStoreCtx) Depth() int { + return int(C.X509_STORE_CTX_get_error_depth(csc.ctx)) } -// the certicate returned is only valid for the lifetime of the underlying +// the certificate returned is only valid for the lifetime of the underlying // X509_STORE_CTX -func (self *CertificateStoreCtx) GetCurrentCert() *Certificate { - x509 := C.X509_STORE_CTX_get_current_cert(self.ctx) +func (csc *CertificateStoreCtx) GetCurrentCert() *Certificate { + x509 := C.X509_STORE_CTX_get_current_cert(csc.ctx) if x509 == nil { return nil } // add a ref - if 1 != C.X_X509_add_ref(x509) { + if C.X_X509_add_ref(x509) != 1 { return nil } cert := &Certificate{ @@ -361,6 +361,32 @@ func (c *Ctx) LoadVerifyLocations(ca_file string, ca_path string) error { return nil } +type Version int + +const ( + SSL3_VERSION Version = C.SSL3_VERSION + TLS1_VERSION Version = C.TLS1_VERSION + TLS1_1_VERSION Version = C.TLS1_1_VERSION + TLS1_2_VERSION Version = C.TLS1_2_VERSION + TLS1_3_VERSION Version = C.TLS1_3_VERSION + DTLS1_VERSION Version = C.DTLS1_VERSION + DTLS1_2_VERSION Version = C.DTLS1_2_VERSION +) + +// SetMinProtoVersion sets the minimum supported protocol version for the Ctx. +// http://www.openssl.org/docs/ssl/SSL_CTX_set_min_proto_version.html +func (c *Ctx) SetMinProtoVersion(version Version) bool { + return C.X_SSL_CTX_set_min_proto_version( + c.ctx, C.int(version)) == 1 +} + +// SetMaxProtoVersion sets the maximum supported protocol version for the Ctx. +// http://www.openssl.org/docs/ssl/SSL_CTX_set_max_proto_version.html +func (c *Ctx) SetMaxProtoVersion(version Version) bool { + return C.X_SSL_CTX_set_max_proto_version( + c.ctx, C.int(version)) == 1 +} + type Options int const ( @@ -430,7 +456,7 @@ func go_ssl_ctx_verify_cb_thunk(p unsafe.Pointer, ok C.int, ctx *C.X509_STORE_CT os.Exit(1) } }() - verify_cb := (*Ctx)(p).verify_cb + verify_cb := pointer.Restore(p).(*Ctx).verify_cb // set up defaults just in case verify_cb is nil if verify_cb != nil { store := &CertificateStoreCtx{ctx: ctx} @@ -522,6 +548,29 @@ func (c *Ctx) SetCipherList(list string) error { return nil } +// SetNextProtos sets Negotiation protocol to the ctx. +func (c *Ctx) SetNextProtos(protos []string) error { + if len(protos) == 0 { + return nil + } + vector := make([]byte, 0) + for _, proto := range protos { + if len(proto) > 255 { + return fmt.Errorf( + "proto length can't be more than 255. But got a proto %s with length %d", + proto, len(proto)) + } + vector = append(vector, byte(uint8(len(proto)))) + vector = append(vector, []byte(proto)...) + } + ret := int(C.SSL_CTX_set_alpn_protos(c.ctx, (*C.uchar)(unsafe.Pointer(&vector[0])), + C.uint(len(vector)))) + if ret != 0 { + return errors.New("error while setting protos to ctx") + } + return nil +} + type SessionCacheModes int const ( diff --git a/dh_test.go b/dh_test.go index fbe3e356..811020b7 100644 --- a/dh_test.go +++ b/dh_test.go @@ -40,7 +40,7 @@ func TestECDH(t *testing.T) { t.Fatal(err) } - if bytes.Compare(mySecret, theirSecret) != 0 { + if !bytes.Equal(mySecret, theirSecret) { t.Fatal("shared secrets are different") } } diff --git a/extension.c b/extension.c new file mode 100644 index 00000000..99f1ca3d --- /dev/null +++ b/extension.c @@ -0,0 +1,40 @@ + + +#include +#include + +const unsigned char * get_extention(X509 *x, int NID, int *data_len){ + int loc; + ASN1_OCTET_STRING *octet_str; + long xlen; + int tag, xclass; + + loc = X509_get_ext_by_NID( x, NID, -1); + X509_EXTENSION *ex = X509_get_ext(x, loc); + octet_str = X509_EXTENSION_get_data(ex); + *data_len = octet_str->length; + return octet_str->data; +} + +// Copied from https://github.com/libtor/openssl/blob/master/demos/x509/mkcert.c#L153 +int add_custom_ext(X509 *cert, int nid,unsigned char *value, int len) +{ + X509_EXTENSION *ex; + ASN1_OCTET_STRING *os = ASN1_OCTET_STRING_new(); + ASN1_OCTET_STRING_set(os,value,len); + X509V3_CTX ctx; + /* This sets the 'context' of the extensions. */ + /* No configuration database */ + X509V3_set_ctx_nodb(&ctx); + /* Issuer and subject certs: both the target since it is self signed, + * no request and no CRL + */ + X509V3_set_ctx(&ctx, cert, cert, NULL, NULL, 0); + // ref http://openssl.6102.n7.nabble.com/Adding-a-custom-extension-to-a-CSR-td47446.html + ex = X509_EXTENSION_create_by_NID( NULL, nid, 0, os); + if (!X509_add_ext(cert,ex,-1)) + return 0; + + X509_EXTENSION_free(ex); + return 1; +} \ No newline at end of file diff --git a/fips.go b/fips.go index f65e14d3..e187ebb5 100644 --- a/fips.go +++ b/fips.go @@ -16,16 +16,31 @@ package openssl /* #include + +#if OPENSSL_VERSION_NUMBER >= 0x30000000L + int FIPS_mode_set(int ONOFF) { + return 0; + } +#endif + */ import "C" -import "runtime" +import ( + "errors" + "runtime" +) // FIPSModeSet enables a FIPS 140-2 validated mode of operation. // https://wiki.openssl.org/index.php/FIPS_mode_set() +// This call has been deleted from OpenSSL 3.0. func FIPSModeSet(mode bool) error { runtime.LockOSThread() defer runtime.UnlockOSThread() + if C.OPENSSL_VERSION_NUMBER >= 0x30000000 { + return errors.New("FIPS_mode_set() has been deleted from OpenSSL 3.0") + } + var r C.int if mode { r = C.FIPS_mode_set(1) diff --git a/go.mod b/go.mod index 73f3bbfe..b2f36ea9 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,10 @@ -module github.com/spacemonkeygo/openssl +module github.com/libp2p/go-openssl -require github.com/spacemonkeygo/spacelog v0.0.0-20180420211403-2296661a0572 +require ( + github.com/mattn/go-pointer v0.0.1 + github.com/spacemonkeygo/spacelog v0.0.0-20180420211403-2296661a0572 +) + +require golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb // indirect + +go 1.18 diff --git a/go.sum b/go.sum index 1b0ecc56..0e9b14d8 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,6 @@ +github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0= +github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc= github.com/spacemonkeygo/spacelog v0.0.0-20180420211403-2296661a0572 h1:RC6RW7j+1+HkWaX/Yh71Ee5ZHaHYt7ZP4sQgUrm6cDU= github.com/spacemonkeygo/spacelog v0.0.0-20180420211403-2296661a0572/go.mod h1:w0SWMsp6j9O/dk4/ZpIhL+3CkG8ofA2vuv7k+ltqUMc= +golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb h1:fgwFCsaw9buMuxNd6+DQfAuSFqbNiQZpcgJQAgJsK6k= +golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/hmac.go b/hmac.go index a8640cfa..77e8dc58 100644 --- a/hmac.go +++ b/hmac.go @@ -74,7 +74,7 @@ func (h *HMAC) Write(data []byte) (n int, err error) { } func (h *HMAC) Reset() error { - if 1 != C.X_HMAC_Init_ex(h.ctx, nil, 0, nil, nil) { + if C.X_HMAC_Init_ex(h.ctx, nil, 0, nil, nil) != 1 { return errors.New("failed to reset HMAC_CTX") } return nil diff --git a/hostname.go b/hostname.go index c92d959e..9ef4ba29 100644 --- a/hostname.go +++ b/hostname.go @@ -17,18 +17,26 @@ package openssl /* #include #include -#include +#if OPENSSL_VERSION_NUMBER >= 0x30000000L + #include + typedef const char x509char; +#else + #include -#ifndef X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT -#define X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT 0x1 -#define X509_CHECK_FLAG_NO_WILDCARDS 0x2 + #ifndef X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT + #define X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT 0x1 + #define X509_CHECK_FLAG_NO_WILDCARDS 0x2 -extern int X509_check_host(X509 *x, const unsigned char *chk, size_t chklen, - unsigned int flags, char **peername); -extern int X509_check_email(X509 *x, const unsigned char *chk, size_t chklen, - unsigned int flags); -extern int X509_check_ip(X509 *x, const unsigned char *chk, size_t chklen, - unsigned int flags); + extern int X509_check_host(X509 *x, const unsigned char *chk, size_t chklen, + unsigned int flags, char **peername); + extern int X509_check_email(X509 *x, const unsigned char *chk, size_t chklen, + unsigned int flags); + extern int X509_check_ip(X509 *x, const unsigned char *chk, size_t chklen, + unsigned int flags); + typedef const unsigned char x509char; + #else + typedef const char x509char; + #endif #endif */ import "C" @@ -40,7 +48,7 @@ import ( ) var ( - ValidationError = errors.New("Host validation error") + ValidationError = errors.New("host validation error") //lint:ignore ST1012 rename may cause breaking changes; research before renaming. ) type CheckFlags int @@ -59,7 +67,7 @@ func (c *Certificate) CheckHost(host string, flags CheckFlags) error { chost := unsafe.Pointer(C.CString(host)) defer C.free(chost) - rv := C.X509_check_host(c.x, (*C.uchar)(chost), C.size_t(len(host)), + rv := C.X509_check_host(c.x, (*C.x509char)(chost), C.size_t(len(host)), C.uint(flags), nil) if rv > 0 { return nil @@ -78,7 +86,7 @@ func (c *Certificate) CheckHost(host string, flags CheckFlags) error { func (c *Certificate) CheckEmail(email string, flags CheckFlags) error { cemail := unsafe.Pointer(C.CString(email)) defer C.free(cemail) - rv := C.X509_check_email(c.x, (*C.uchar)(cemail), C.size_t(len(email)), + rv := C.X509_check_email(c.x, (*C.x509char)(cemail), C.size_t(len(email)), C.uint(flags)) if rv > 0 { return nil diff --git a/init.go b/init.go index 17dc6f38..b8c7a0da 100644 --- a/init.go +++ b/init.go @@ -18,65 +18,69 @@ Package openssl is a light wrapper around OpenSSL for Go. It strives to provide a near-drop-in replacement for the Go standard library tls package, while allowing for: -Performance +# Performance OpenSSL is battle-tested and optimized C. While Go's built-in library shows great promise, it is still young and in some places, inefficient. This simple OpenSSL wrapper can often do at least 2x with the same cipher and protocol. On my lappytop, I get the following benchmarking speeds: - BenchmarkSHA1Large_openssl 1000 2611282 ns/op 401.56 MB/s - BenchmarkSHA1Large_stdlib 500 3963983 ns/op 264.53 MB/s - BenchmarkSHA1Small_openssl 1000000 3476 ns/op 0.29 MB/s - BenchmarkSHA1Small_stdlib 5000000 550 ns/op 1.82 MB/s - BenchmarkSHA256Large_openssl 200 8085314 ns/op 129.69 MB/s - BenchmarkSHA256Large_stdlib 100 18948189 ns/op 55.34 MB/s - BenchmarkSHA256Small_openssl 1000000 4262 ns/op 0.23 MB/s - BenchmarkSHA256Small_stdlib 1000000 1444 ns/op 0.69 MB/s - BenchmarkOpenSSLThroughput 100000 21634 ns/op 47.33 MB/s - BenchmarkStdlibThroughput 50000 58974 ns/op 17.36 MB/s - -Interoperability + + BenchmarkSHA1Large_openssl 1000 2611282 ns/op 401.56 MB/s + BenchmarkSHA1Large_stdlib 500 3963983 ns/op 264.53 MB/s + BenchmarkSHA1Small_openssl 1000000 3476 ns/op 0.29 MB/s + BenchmarkSHA1Small_stdlib 5000000 550 ns/op 1.82 MB/s + BenchmarkSHA256Large_openssl 200 8085314 ns/op 129.69 MB/s + BenchmarkSHA256Large_stdlib 100 18948189 ns/op 55.34 MB/s + BenchmarkSHA256Small_openssl 1000000 4262 ns/op 0.23 MB/s + BenchmarkSHA256Small_stdlib 1000000 1444 ns/op 0.69 MB/s + BenchmarkOpenSSLThroughput 100000 21634 ns/op 47.33 MB/s + BenchmarkStdlibThroughput 50000 58974 ns/op 17.36 MB/s + +# Interoperability Many systems support OpenSSL with a variety of plugins and modules for things, such as hardware acceleration in embedded devices. -Greater flexibility and configuration +# Greater flexibility and configuration OpenSSL allows for far greater configuration of corner cases and backwards compatibility (such as support of SSLv2). You shouldn't be using SSLv2 if you can help but, but sometimes you can't help it. -Security +# Security Yeah yeah, Heartbleed. But according to the author of the standard library's TLS implementation, Go's TLS library is vulnerable to timing attacks. And whether or not OpenSSL received the appropriate amount of scrutiny pre-Heartbleed, it sure is receiving it now. -Usage +# Usage Starting an HTTP server that uses OpenSSL is very easy. It's as simple as: - log.Fatal(openssl.ListenAndServeTLS( - ":8443", "my_server.crt", "my_server.key", myHandler)) + + log.Fatal(openssl.ListenAndServeTLS( + ":8443", "my_server.crt", "my_server.key", myHandler)) Getting a net.Listener that uses OpenSSL is also easy: - ctx, err := openssl.NewCtxFromFiles("my_server.crt", "my_server.key") - if err != nil { - log.Fatal(err) - } - l, err := openssl.Listen("tcp", ":7777", ctx) + + ctx, err := openssl.NewCtxFromFiles("my_server.crt", "my_server.key") + if err != nil { + log.Fatal(err) + } + l, err := openssl.Listen("tcp", ":7777", ctx) Making a client connection is straightforward too: - ctx, err := NewCtx() - if err != nil { - log.Fatal(err) - } - err = ctx.LoadVerifyLocations("/etc/ssl/certs/ca-certificates.crt", "") - if err != nil { - log.Fatal(err) - } - conn, err := openssl.Dial("tcp", "localhost:7777", ctx, 0) + + ctx, err := NewCtx() + if err != nil { + log.Fatal(err) + } + err = ctx.LoadVerifyLocations("/etc/ssl/certs/ca-certificates.crt", "") + if err != nil { + log.Fatal(err) + } + conn, err := openssl.Dial("tcp", "localhost:7777", ctx, 0) Help wanted: To get this library to work with net/http's client, we had to fork net/http. It would be nice if an alternate http client library @@ -88,14 +92,13 @@ package openssl import "C" import ( - "errors" "fmt" "strings" ) func init() { if rc := C.X_shim_init(); rc != 0 { - panic(fmt.Errorf("X_shim_init failed with %d", rc)) + panic(fmt.Errorf("x_shim_init failed with %d", rc)) } } @@ -113,5 +116,5 @@ func errorFromErrorQueue() error { C.GoString(C.ERR_func_error_string(err)), C.GoString(C.ERR_reason_error_string(err)))) } - return errors.New(fmt.Sprintf("SSL errors: %s", strings.Join(errs, "\n"))) + return fmt.Errorf("SSL errors: %s", strings.Join(errs, "\n")) } diff --git a/init_posix.go b/init_posix.go index 2da7f957..8c4ffdfc 100644 --- a/init_posix.go +++ b/init_posix.go @@ -12,8 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build linux darwin solaris -// +build !windows +//go:build (linux || darwin || solaris || freebsd || openbsd) && !windows package openssl diff --git a/init_windows.go b/init_windows.go index 051133c3..22c7e126 100644 --- a/init_windows.go +++ b/init_windows.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build windows +//go:build windows package openssl diff --git a/key.go b/key.go index 91ea98a7..cb685d8a 100644 --- a/key.go +++ b/key.go @@ -19,7 +19,7 @@ import "C" import ( "errors" - "io/ioutil" + "io" "runtime" "unsafe" ) @@ -85,6 +85,12 @@ type PublicKey interface { // `KeyType() == KeyTypeRSA2` would both have `BaseType() == KeyTypeRSA`. BaseType() NID + // Equal compares the key with the passed in key. + Equal(key PublicKey) bool + + // Size returns the size (in bytes) of signatures created with this key. + Size() int + evpPKey() *C.EVP_PKEY } @@ -109,10 +115,18 @@ type pKey struct { func (key *pKey) evpPKey() *C.EVP_PKEY { return key.key } +func (key *pKey) Equal(other PublicKey) bool { + return C.EVP_PKEY_cmp(key.key, other.evpPKey()) == 1 +} + func (key *pKey) KeyType() NID { return NID(C.EVP_PKEY_id(key.key)) } +func (key *pKey) Size() int { + return int(C.EVP_PKEY_size(key.key)) +} + func (key *pKey) BaseType() NID { return NID(C.EVP_PKEY_base_id(key.key)) } @@ -129,36 +143,36 @@ func (key *pKey) SignPKCS1v15(method Method, data []byte) ([]byte, error) { return nil, errors.New("signpkcs1v15: 0-length data or non-null digest") } - if 1 != C.X_EVP_DigestSignInit(ctx, nil, nil, nil, key.key) { + if C.X_EVP_DigestSignInit(ctx, nil, nil, nil, key.key) != 1 { return nil, errors.New("signpkcs1v15: failed to init signature") } // evp signatures are 64 bytes - sig := make([]byte, 64, 64) + sig := make([]byte, 64) var sigblen C.size_t = 64 - if 1 != C.X_EVP_DigestSign(ctx, - ((*C.uchar)(unsafe.Pointer(&sig[0]))), + if C.X_EVP_DigestSign(ctx, + (*C.uchar)(unsafe.Pointer(&sig[0])), &sigblen, (*C.uchar)(unsafe.Pointer(&data[0])), - C.size_t(len(data))) { + C.size_t(len(data))) != 1 { return nil, errors.New("signpkcs1v15: failed to do one-shot signature") } return sig[:sigblen], nil } else { - if 1 != C.X_EVP_SignInit(ctx, method) { + if C.X_EVP_SignInit(ctx, method) != 1 { return nil, errors.New("signpkcs1v15: failed to init signature") } if len(data) > 0 { - if 1 != C.X_EVP_SignUpdate( - ctx, unsafe.Pointer(&data[0]), C.uint(len(data))) { + if C.X_EVP_SignUpdate( + ctx, unsafe.Pointer(&data[0]), C.uint(len(data))) != 1 { return nil, errors.New("signpkcs1v15: failed to update signature") } } sig := make([]byte, C.X_EVP_PKEY_size(key.key)) var sigblen C.uint - if 1 != C.X_EVP_SignFinal(ctx, - ((*C.uchar)(unsafe.Pointer(&sig[0]))), &sigblen, key.key) { + if C.X_EVP_SignFinal(ctx, + (*C.uchar)(unsafe.Pointer(&sig[0])), &sigblen, key.key) != 1 { return nil, errors.New("signpkcs1v15: failed to finalize signature") } return sig[:sigblen], nil @@ -169,39 +183,43 @@ func (key *pKey) VerifyPKCS1v15(method Method, data, sig []byte) error { ctx := C.X_EVP_MD_CTX_new() defer C.X_EVP_MD_CTX_free(ctx) + if len(sig) == 0 { + return errors.New("verifypkcs1v15: 0-length sig") + } + if key.KeyType() == KeyTypeED25519 { // do ED specific one-shot sign - if method != nil || len(data) == 0 || len(sig) == 0 { - return errors.New("verifypkcs1v15: 0-length data or sig or non-null digest") + if method != nil || len(data) == 0 { + return errors.New("verifypkcs1v15: 0-length data or non-null digest") } - if 1 != C.X_EVP_DigestVerifyInit(ctx, nil, nil, nil, key.key) { + if C.X_EVP_DigestVerifyInit(ctx, nil, nil, nil, key.key) != 1 { return errors.New("verifypkcs1v15: failed to init verify") } - if 1 != C.X_EVP_DigestVerify(ctx, - ((*C.uchar)(unsafe.Pointer(&sig[0]))), + if C.X_EVP_DigestVerify(ctx, + (*C.uchar)(unsafe.Pointer(&sig[0])), C.size_t(len(sig)), (*C.uchar)(unsafe.Pointer(&data[0])), - C.size_t(len(data))) { + C.size_t(len(data))) != 1 { return errors.New("verifypkcs1v15: failed to do one-shot verify") } return nil } else { - if 1 != C.X_EVP_VerifyInit(ctx, method) { + if C.X_EVP_VerifyInit(ctx, method) != 1 { return errors.New("verifypkcs1v15: failed to init verify") } if len(data) > 0 { - if 1 != C.X_EVP_VerifyUpdate( - ctx, unsafe.Pointer(&data[0]), C.uint(len(data))) { + if C.X_EVP_VerifyUpdate( + ctx, unsafe.Pointer(&data[0]), C.uint(len(data))) != 1 { return errors.New("verifypkcs1v15: failed to update verify") } } - if 1 != C.X_EVP_VerifyFinal(ctx, - ((*C.uchar)(unsafe.Pointer(&sig[0]))), C.uint(len(sig)), key.key) { + if C.X_EVP_VerifyFinal(ctx, + (*C.uchar)(unsafe.Pointer(&sig[0])), C.uint(len(sig)), key.key) != 1 { return errors.New("verifypkcs1v15: failed to finalize verify") } return nil @@ -224,7 +242,7 @@ func (key *pKey) MarshalPKCS1PrivateKeyPEM() (pem_block []byte, return nil, errors.New("failed dumping private key") } - return ioutil.ReadAll(asAnyBio(bio)) + return io.ReadAll(asAnyBio(bio)) } func (key *pKey) MarshalPKCS1PrivateKeyDER() (der_block []byte, @@ -239,7 +257,7 @@ func (key *pKey) MarshalPKCS1PrivateKeyDER() (der_block []byte, return nil, errors.New("failed dumping private key der") } - return ioutil.ReadAll(asAnyBio(bio)) + return io.ReadAll(asAnyBio(bio)) } func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte, @@ -254,7 +272,7 @@ func (key *pKey) MarshalPKIXPublicKeyPEM() (pem_block []byte, return nil, errors.New("failed dumping public key pem") } - return ioutil.ReadAll(asAnyBio(bio)) + return io.ReadAll(asAnyBio(bio)) } func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte, @@ -269,7 +287,7 @@ func (key *pKey) MarshalPKIXPublicKeyDER() (der_block []byte, return nil, errors.New("failed dumping public key der") } - return ioutil.ReadAll(asAnyBio(bio)) + return io.ReadAll(asAnyBio(bio)) } // LoadPrivateKeyFromPEM loads a private key from a PEM-encoded block. diff --git a/key_test.go b/key_test.go index 753e3784..25de69c5 100644 --- a/key_test.go +++ b/key_test.go @@ -22,7 +22,7 @@ import ( "crypto/x509" "encoding/hex" pem_pkg "encoding/pem" - "io/ioutil" + "os" "testing" ) @@ -36,6 +36,10 @@ func TestMarshal(t *testing.T) { t.Fatal(err) } + if !key.Equal(key) { + t.Fatal("key not equal to itself") + } + privateBlock, _ := pem_pkg.Decode(keyBytes) key, err = LoadPrivateKeyFromDER(privateBlock.Bytes) if err != nil { @@ -47,8 +51,8 @@ func TestMarshal(t *testing.T) { t.Fatal(err) } if !bytes.Equal(pem, certBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", certBytes, 0644) + os.WriteFile("generated", pem, 0644) + os.WriteFile("hardcoded", certBytes, 0644) t.Fatal("invalid cert pem bytes") } @@ -57,8 +61,8 @@ func TestMarshal(t *testing.T) { t.Fatal(err) } if !bytes.Equal(pem, keyBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", keyBytes, 0644) + os.WriteFile("generated", pem, 0644) + os.WriteFile("hardcoded", keyBytes, 0644) t.Fatal("invalid private key pem bytes") } tls_cert, err := tls.X509KeyPair(certBytes, keyBytes) @@ -90,8 +94,8 @@ func TestMarshal(t *testing.T) { t.Fatal(err) } if !bytes.Equal(der, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(der)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + os.WriteFile("generated", []byte(hex.Dump(der)), 0644) + os.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) t.Fatal("invalid public key der bytes") } @@ -102,8 +106,8 @@ func TestMarshal(t *testing.T) { tls_pem := pem_pkg.EncodeToMemory(&pem_pkg.Block{ Type: "PUBLIC KEY", Bytes: tls_der}) if !bytes.Equal(pem, tls_pem) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", tls_pem, 0644) + os.WriteFile("generated", pem, 0644) + os.WriteFile("hardcoded", tls_pem, 0644) t.Fatal("invalid public key pem bytes") } @@ -128,14 +132,14 @@ func TestMarshal(t *testing.T) { } if !bytes.Equal(new_der_from_der, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(new_der_from_der)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + os.WriteFile("generated", []byte(hex.Dump(new_der_from_der)), 0644) + os.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) t.Fatal("invalid public key der bytes") } if !bytes.Equal(new_der_from_pem, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(new_der_from_pem)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + os.WriteFile("generated", []byte(hex.Dump(new_der_from_pem)), 0644) + os.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) t.Fatal("invalid public key der bytes") } } @@ -183,12 +187,7 @@ func TestGenerateEd25519(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = key.MarshalPKIXPublicKeyPEM() - if err != nil { - t.Fatal(err) - } - _, err = key.MarshalPKCS1PrivateKeyPEM() - if err != nil { + if _, err = key.MarshalPKIXPublicKeyPEM(); err != nil { t.Fatal(err) } } @@ -285,8 +284,7 @@ func TestSignED25519(t *testing.T) { } func TestMarshalEC(t *testing.T) { - key, err := LoadPrivateKeyFromPEM(prime256v1KeyBytes) - if err != nil { + if _, err := LoadPrivateKeyFromPEM(prime256v1KeyBytes); err != nil { t.Fatal(err) } cert, err := LoadCertificateFromPEM(prime256v1CertBytes) @@ -295,7 +293,7 @@ func TestMarshalEC(t *testing.T) { } privateBlock, _ := pem_pkg.Decode(prime256v1KeyBytes) - key, err = LoadPrivateKeyFromDER(privateBlock.Bytes) + key, err := LoadPrivateKeyFromDER(privateBlock.Bytes) if err != nil { t.Fatal(err) } @@ -305,8 +303,8 @@ func TestMarshalEC(t *testing.T) { t.Fatal(err) } if !bytes.Equal(pem, prime256v1CertBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", prime256v1CertBytes, 0644) + os.WriteFile("generated", pem, 0644) + os.WriteFile("hardcoded", prime256v1CertBytes, 0644) t.Fatal("invalid cert pem bytes") } @@ -315,8 +313,8 @@ func TestMarshalEC(t *testing.T) { t.Fatal(err) } if !bytes.Equal(pem, prime256v1KeyBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", prime256v1KeyBytes, 0644) + os.WriteFile("generated", pem, 0644) + os.WriteFile("hardcoded", prime256v1KeyBytes, 0644) t.Fatal("invalid private key pem bytes") } tls_cert, err := tls.X509KeyPair(prime256v1CertBytes, prime256v1KeyBytes) @@ -351,8 +349,8 @@ func TestMarshalEC(t *testing.T) { t.Fatal(err) } if !bytes.Equal(der, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(der)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + os.WriteFile("generated", []byte(hex.Dump(der)), 0644) + os.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) t.Fatal("invalid public key der bytes") } @@ -363,8 +361,8 @@ func TestMarshalEC(t *testing.T) { tls_pem := pem_pkg.EncodeToMemory(&pem_pkg.Block{ Type: "PUBLIC KEY", Bytes: tls_der}) if !bytes.Equal(pem, tls_pem) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", tls_pem, 0644) + os.WriteFile("generated", pem, 0644) + os.WriteFile("hardcoded", tls_pem, 0644) t.Fatal("invalid public key pem bytes") } @@ -389,14 +387,14 @@ func TestMarshalEC(t *testing.T) { } if !bytes.Equal(new_der_from_der, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(new_der_from_der)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + os.WriteFile("generated", []byte(hex.Dump(new_der_from_der)), 0644) + os.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) t.Fatal("invalid public key der bytes") } if !bytes.Equal(new_der_from_pem, tls_der) { - ioutil.WriteFile("generated", []byte(hex.Dump(new_der_from_pem)), 0644) - ioutil.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) + os.WriteFile("generated", []byte(hex.Dump(new_der_from_pem)), 0644) + os.WriteFile("hardcoded", []byte(hex.Dump(tls_der)), 0644) t.Fatal("invalid public key der bytes") } } @@ -406,8 +404,7 @@ func TestMarshalEd25519(t *testing.T) { t.SkipNow() } - key, err := LoadPrivateKeyFromPEM(ed25519KeyBytes) - if err != nil { + if _, err := LoadPrivateKeyFromPEM(ed25519KeyBytes); err != nil { t.Fatal(err) } cert, err := LoadCertificateFromPEM(ed25519CertBytes) @@ -416,7 +413,7 @@ func TestMarshalEd25519(t *testing.T) { } privateBlock, _ := pem_pkg.Decode(ed25519KeyBytes) - key, err = LoadPrivateKeyFromDER(privateBlock.Bytes) + key, err := LoadPrivateKeyFromDER(privateBlock.Bytes) if err != nil { t.Fatal(err) } @@ -426,22 +423,18 @@ func TestMarshalEd25519(t *testing.T) { t.Fatal(err) } if !bytes.Equal(pem, ed25519CertBytes) { - ioutil.WriteFile("generated", pem, 0644) - ioutil.WriteFile("hardcoded", ed25519CertBytes, 0644) + os.WriteFile("generated", pem, 0644) + os.WriteFile("hardcoded", ed25519CertBytes, 0644) t.Fatal("invalid cert pem bytes") } - pem, err = key.MarshalPKCS1PrivateKeyPEM() - if err != nil { - t.Fatal(err) - } + // NOTE: Ed25519 cannot be marshalled to PEM. - der, err := key.MarshalPKCS1PrivateKeyDER() - if err != nil { + if _, err := key.MarshalPKCS1PrivateKeyDER(); err != nil { t.Fatal(err) } - der, err = key.MarshalPKIXPublicKeyDER() + der, err := key.MarshalPKIXPublicKeyDER() if err != nil { t.Fatal(err) } @@ -451,22 +444,22 @@ func TestMarshalEd25519(t *testing.T) { t.Fatal(err) } - loaded_pubkey_from_pem, err := LoadPublicKeyFromPEM(pem) + loadedPubkeyFromPem, err := LoadPublicKeyFromPEM(pem) if err != nil { t.Fatal(err) } - loaded_pubkey_from_der, err := LoadPublicKeyFromDER(der) + loadedPubkeyFromDer, err := LoadPublicKeyFromDER(der) if err != nil { t.Fatal(err) } - _, err = loaded_pubkey_from_pem.MarshalPKIXPublicKeyDER() + _, err = loadedPubkeyFromPem.MarshalPKIXPublicKeyDER() if err != nil { t.Fatal(err) } - _, err = loaded_pubkey_from_der.MarshalPKIXPublicKeyDER() + _, err = loadedPubkeyFromDer.MarshalPKIXPublicKeyDER() if err != nil { t.Fatal(err) } diff --git a/md4.go b/md4.go index e5cc7d86..95d9d2d2 100644 --- a/md4.go +++ b/md4.go @@ -51,8 +51,11 @@ func (s *MD4Hash) Close() { } func (s *MD4Hash) Reset() error { - if 1 != C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_md4(), engineRef(s.engine)) { - return errors.New("openssl: md4: cannot init digest ctx") + runtime.LockOSThread() + defer runtime.UnlockOSThread() + if C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_md4(), engineRef(s.engine)) != 1 { + return errors.New("openssl: md4: cannot init digest ctx: " + + errorFromErrorQueue().Error()) } return nil } @@ -61,16 +64,16 @@ func (s *MD4Hash) Write(p []byte) (n int, err error) { if len(p) == 0 { return 0, nil } - if 1 != C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), - C.size_t(len(p))) { + if C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), + C.size_t(len(p))) != 1 { return 0, errors.New("openssl: md4: cannot update digest") } return len(p), nil } func (s *MD4Hash) Sum() (result [16]byte, err error) { - if 1 != C.X_EVP_DigestFinal_ex(s.ctx, - (*C.uchar)(unsafe.Pointer(&result[0])), nil) { + if C.X_EVP_DigestFinal_ex(s.ctx, + (*C.uchar)(unsafe.Pointer(&result[0])), nil) != 1 { return result, errors.New("openssl: md4: cannot finalize ctx") } return result, s.Reset() diff --git a/md4_test.go b/md4_test.go index b31c7e64..9041ebfe 100644 --- a/md4_test.go +++ b/md4_test.go @@ -56,7 +56,19 @@ var md4Examples = []struct{ out, in string }{ {"6e593341e62194911d5cc31e39835f27", "c5e4bc73821faa34adf9468441ffd97520a96cd5debda4d51edcaaf2b23fbd"}, } +func skipIfMD4Unsupported(t testing.TB) { + t.Helper() + + hash, err := NewMD4Hash() + if err != nil { + t.Skip("MD4 is not supported by OpenSSL") + } + hash.Close() +} + func TestMD4Examples(t *testing.T) { + skipIfMD4Unsupported(t) + for _, ex := range md4Examples { buf, err := hex.DecodeString(ex.in) if err != nil { @@ -75,6 +87,8 @@ func TestMD4Examples(t *testing.T) { } func TestMD4Writer(t *testing.T) { + skipIfMD4Unsupported(t) + ohash, err := NewMD4Hash() if err != nil { t.Fatal(err) @@ -120,9 +134,13 @@ func benchmarkMD4(b *testing.B, length int64, fn md4func) { } func BenchmarkMD4Large_openssl(b *testing.B) { + skipIfMD4Unsupported(b) + benchmarkMD4(b, 1024*1024, func(buf []byte) { MD4(buf) }) } func BenchmarkMD4Small_openssl(b *testing.B) { + skipIfMD4Unsupported(b) + benchmarkMD4(b, 1, func(buf []byte) { MD4(buf) }) } diff --git a/md5.go b/md5.go index 82f2eb2f..d7e771ee 100644 --- a/md5.go +++ b/md5.go @@ -51,7 +51,7 @@ func (s *MD5Hash) Close() { } func (s *MD5Hash) Reset() error { - if 1 != C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_md5(), engineRef(s.engine)) { + if C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_md5(), engineRef(s.engine)) != 1 { return errors.New("openssl: md5: cannot init digest ctx") } return nil @@ -61,16 +61,16 @@ func (s *MD5Hash) Write(p []byte) (n int, err error) { if len(p) == 0 { return 0, nil } - if 1 != C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), - C.size_t(len(p))) { + if C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), + C.size_t(len(p))) != 1 { return 0, errors.New("openssl: md5: cannot update digest") } return len(p), nil } func (s *MD5Hash) Sum() (result [16]byte, err error) { - if 1 != C.X_EVP_DigestFinal_ex(s.ctx, - (*C.uchar)(unsafe.Pointer(&result[0])), nil) { + if C.X_EVP_DigestFinal_ex(s.ctx, + (*C.uchar)(unsafe.Pointer(&result[0])), nil) != 1 { return result, errors.New("openssl: md5: cannot finalize ctx") } return result, s.Reset() diff --git a/net.go b/net.go index 54beb8ee..b2293c7c 100644 --- a/net.go +++ b/net.go @@ -17,6 +17,7 @@ package openssl import ( "errors" "net" + "time" ) type listener struct { @@ -80,6 +81,18 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) { return DialSession(network, addr, ctx, flags, nil) } +// DialTimeout acts like Dial but takes a timeout for network dial. +// +// The timeout includes only network dial. It does not include OpenSSL calls. +// +// See func Dial for a description of the network, addr, ctx and flags +// parameters. +func DialTimeout(network, addr string, timeout time.Duration, ctx *Ctx, + flags DialFlags) (*Conn, error) { + d := net.Dialer{Timeout: timeout} + return dialSession(d, network, addr, ctx, flags, nil) +} + // DialSession will connect to network/address and then wrap the corresponding // underlying connection with an OpenSSL client connection using context ctx. // If flags includes InsecureSkipHostVerification, the server certificate's @@ -95,7 +108,12 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) { // can be retrieved from the GetSession method on the Conn. func DialSession(network, addr string, ctx *Ctx, flags DialFlags, session []byte) (*Conn, error) { + var d net.Dialer + return dialSession(d, network, addr, ctx, flags, session) +} +func dialSession(d net.Dialer, network, addr string, ctx *Ctx, flags DialFlags, + session []byte) (*Conn, error) { host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -108,7 +126,8 @@ func DialSession(network, addr string, ctx *Ctx, flags DialFlags, } // TODO: use operating system default certificate chain? } - c, err := net.Dial(network, addr) + + c, err := d.Dial(network, addr) if err != nil { return nil, err } diff --git a/object.go b/object.go new file mode 100644 index 00000000..4d908e6c --- /dev/null +++ b/object.go @@ -0,0 +1,24 @@ +// Copyright (C) 2020. See AUTHORS. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package openssl + +// #include "shim.h" +import "C" + +// CreateObjectIdentifier creates ObjectIdentifier and returns NID for the created +// ObjectIdentifier +func CreateObjectIdentifier(oid string, shortName string, longName string) NID { + return NID(C.OBJ_create(C.CString(oid), C.CString(shortName), C.CString(longName))) +} diff --git a/pem.go b/pem.go index c8b0c1cf..6127cf07 100644 --- a/pem.go +++ b/pem.go @@ -19,14 +19,10 @@ import ( ) var pemSplit *regexp.Regexp = regexp.MustCompile(`(?sm)` + - `(^-----[\s-]*?BEGIN.*?-----$` + + `(^-----[\s-]*?BEGIN.*?-----[\s-]*?$` + `.*?` + - `^-----[\s-]*?END.*?-----$)`) + `^-----[\s-]*?END.*?-----[\s-]*?$)`) func SplitPEM(data []byte) [][]byte { - var results [][]byte - for _, block := range pemSplit.FindAll(data, -1) { - results = append(results, block) - } - return results + return pemSplit.FindAll(data, -1) } diff --git a/sha1.go b/sha1.go index c227bee8..ab4ad87f 100644 --- a/sha1.go +++ b/sha1.go @@ -58,7 +58,7 @@ func engineRef(e *Engine) *C.ENGINE { } func (s *SHA1Hash) Reset() error { - if 1 != C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha1(), engineRef(s.engine)) { + if C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha1(), engineRef(s.engine)) != 1 { return errors.New("openssl: sha1: cannot init digest ctx") } return nil @@ -68,16 +68,16 @@ func (s *SHA1Hash) Write(p []byte) (n int, err error) { if len(p) == 0 { return 0, nil } - if 1 != C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), - C.size_t(len(p))) { + if C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), + C.size_t(len(p))) != 1 { return 0, errors.New("openssl: sha1: cannot update digest") } return len(p), nil } func (s *SHA1Hash) Sum() (result [20]byte, err error) { - if 1 != C.X_EVP_DigestFinal_ex(s.ctx, - (*C.uchar)(unsafe.Pointer(&result[0])), nil) { + if C.X_EVP_DigestFinal_ex(s.ctx, + (*C.uchar)(unsafe.Pointer(&result[0])), nil) != 1 { return result, errors.New("openssl: sha1: cannot finalize ctx") } return result, s.Reset() diff --git a/sha256.go b/sha256.go index d25c7a95..d9189a94 100644 --- a/sha256.go +++ b/sha256.go @@ -51,7 +51,7 @@ func (s *SHA256Hash) Close() { } func (s *SHA256Hash) Reset() error { - if 1 != C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha256(), engineRef(s.engine)) { + if C.X_EVP_DigestInit_ex(s.ctx, C.X_EVP_sha256(), engineRef(s.engine)) != 1 { return errors.New("openssl: sha256: cannot init digest ctx") } return nil @@ -61,16 +61,16 @@ func (s *SHA256Hash) Write(p []byte) (n int, err error) { if len(p) == 0 { return 0, nil } - if 1 != C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), - C.size_t(len(p))) { + if C.X_EVP_DigestUpdate(s.ctx, unsafe.Pointer(&p[0]), + C.size_t(len(p))) != 1 { return 0, errors.New("openssl: sha256: cannot update digest") } return len(p), nil } func (s *SHA256Hash) Sum() (result [32]byte, err error) { - if 1 != C.X_EVP_DigestFinal_ex(s.ctx, - (*C.uchar)(unsafe.Pointer(&result[0])), nil) { + if C.X_EVP_DigestFinal_ex(s.ctx, + (*C.uchar)(unsafe.Pointer(&result[0])), nil) != 1 { return result, errors.New("openssl: sha256: cannot finalize ctx") } return result, s.Reset() diff --git a/shim.c b/shim.c index 6e680841..b27a5743 100644 --- a/shim.c +++ b/shim.c @@ -428,7 +428,7 @@ int X_SSL_session_reused(SSL *ssl) { } int X_SSL_new_index() { - return SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL); + return SSL_get_ex_new_index(0, NULL, NULL, NULL, go_ssl_crypto_ex_free); } int X_SSL_verify_cb(int ok, X509_STORE_CTX* store) { @@ -475,6 +475,14 @@ int X_SSL_CTX_new_index() { return SSL_CTX_get_ex_new_index(0, NULL, NULL, NULL, NULL); } +int X_SSL_CTX_set_min_proto_version(SSL_CTX *ctx, int version) { + return SSL_CTX_set_min_proto_version(ctx, version); +} + +int X_SSL_CTX_set_max_proto_version(SSL_CTX *ctx, int version) { + return SSL_CTX_set_max_proto_version(ctx, version); +} + long X_SSL_CTX_set_options(SSL_CTX* ctx, long options) { return SSL_CTX_set_options(ctx, options); } diff --git a/shim.h b/shim.h index b792822b..94fe8c61 100644 --- a/shim.h +++ b/shim.h @@ -29,6 +29,7 @@ #include #include + #ifndef SSL_MODE_RELEASE_BUFFERS #define SSL_MODE_RELEASE_BUFFERS 0 #endif @@ -66,6 +67,8 @@ extern int X_SSL_verify_cb(int ok, X509_STORE_CTX* store); /* SSL_CTX methods */ extern int X_SSL_CTX_new_index(); +extern int X_SSL_CTX_set_min_proto_version(SSL_CTX *ctx, int version); +extern int X_SSL_CTX_set_max_proto_version(SSL_CTX *ctx, int version); extern long X_SSL_CTX_set_options(SSL_CTX* ctx, long options); extern long X_SSL_CTX_clear_options(SSL_CTX* ctx, long options); extern long X_SSL_CTX_get_options(SSL_CTX* ctx); @@ -89,6 +92,8 @@ extern int X_SSL_CTX_set_tlsext_ticket_key_cb(SSL_CTX *sslctx, extern int X_SSL_CTX_ticket_key_cb(SSL *s, unsigned char key_name[16], unsigned char iv[EVP_MAX_IV_LENGTH], EVP_CIPHER_CTX *cctx, HMAC_CTX *hctx, int enc); +extern int SSL_CTX_set_alpn_protos(SSL_CTX *ctx, const unsigned char *protos, + unsigned int protos_len); /* BIO methods */ extern int X_BIO_get_flags(BIO *b); @@ -170,3 +175,10 @@ extern int X_X509_set_version(X509 *x, long version); /* PEM methods */ extern int X_PEM_write_bio_PrivateKey_traditional(BIO *bio, EVP_PKEY *key, const EVP_CIPHER *enc, unsigned char *kstr, int klen, pem_password_cb *cb, void *u); + +/* Object methods */ +extern int OBJ_create(const char *oid,const char *sn,const char *ln); + +/* Extension helper method */ +extern const unsigned char * get_extention(X509 *x, int NID, int *data_len); +extern int add_custom_ext(X509 *cert, int nid, char *value, int len); \ No newline at end of file diff --git a/ssl.go b/ssl.go index 117c30c0..b187d15d 100644 --- a/ssl.go +++ b/ssl.go @@ -20,6 +20,8 @@ import "C" import ( "os" "unsafe" + + "github.com/mattn/go-pointer" ) type SSLTLSExtErr int @@ -53,7 +55,7 @@ func go_ssl_verify_cb_thunk(p unsafe.Pointer, ok C.int, ctx *C.X509_STORE_CTX) C os.Exit(1) } }() - verify_cb := (*SSL)(p).verify_cb + verify_cb := pointer.Restore(p).(*SSL).verify_cb // set up defaults just in case verify_cb is nil if verify_cb != nil { store := &CertificateStoreCtx{ctx: ctx} @@ -159,11 +161,11 @@ func sni_cb_thunk(p unsafe.Pointer, con *C.SSL, ad unsafe.Pointer, arg unsafe.Po } }() - sni_cb := (*Ctx)(p).sni_cb + sni_cb := pointer.Restore(p).(*Ctx).sni_cb s := &SSL{ssl: con} // This attaches a pointer to our SSL struct into the SNI callback. - C.SSL_set_ex_data(s.ssl, get_ssl_idx(), unsafe.Pointer(s)) + C.SSL_set_ex_data(s.ssl, get_ssl_idx(), pointer.Save(s)) // Note: this is ctx.sni_cb, not C.sni_cb return C.int(sni_cb(s)) diff --git a/ssl_test.go b/ssl_test.go index a0bd9d50..bad56737 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -19,38 +19,35 @@ import ( "crypto/rand" "crypto/tls" "io" - "io/ioutil" "net" "sync" "testing" "time" - "github.com/spacemonkeygo/openssl/utils" + "github.com/libp2p/go-openssl/utils" ) var ( certBytes = []byte(`-----BEGIN CERTIFICATE----- -MIIDxDCCAqygAwIBAgIVAMcK/0VWQr2O3MNfJCydqR7oVELcMA0GCSqGSIb3DQEB -BQUAMIGQMUkwRwYDVQQDE0A1NjdjZGRmYzRjOWZiNTYwZTk1M2ZlZjA1N2M0NGFm -MDdiYjc4MDIzODIxYTA5NThiY2RmMGMwNzJhOTdiMThhMQswCQYDVQQGEwJVUzEN -MAsGA1UECBMEVXRhaDEQMA4GA1UEBxMHTWlkdmFsZTEVMBMGA1UEChMMU3BhY2Ug -TW9ua2V5MB4XDTEzMTIxNzE4MzgyMloXDTIzMTIxNTE4MzgyMlowgZAxSTBHBgNV -BAMTQDM4NTg3ODRkMjU1NTdiNTM1MWZmNjRmMmQzMTQ1ZjkwYTJlMTIzMDM4Y2Yz -Mjc1Yzg1OTM1MjcxYWIzMmNiMDkxCzAJBgNVBAYTAlVTMQ0wCwYDVQQIEwRVdGFo -MRAwDgYDVQQHEwdNaWR2YWxlMRUwEwYDVQQKEwxTcGFjZSBNb25rZXkwggEiMA0G -CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDdf3icNvFsrlrnNLi8SocscqlSbFq+ -pEvmhcSoqgDLqebnqu8Ld73HJJ74MGXEgRX8xZT5FinOML31CR6t9E/j3dqV6p+G -fdlFLe3IqtC0/bPVnCDBirBygBI4uCrMq+1VhAxPWclrDo7l9QRYbsExH9lfn+Ry -vxeNMZiOASasvVZNncY8E9usBGRdH17EfDL/TPwXqWOLyxSN5o54GTztjjy9w9CG -QP7jcCueKYyQJQCtEmnwc6P/q6/EPv5R6drBkX6loAPtmCUAkHqxkWOJrRq/v7Pw -zRYhfY+ZpVHGc7WEkDnLzRiUypr1C9oxvLKS10etZEIwEdKyOkSg2fdPAgMBAAGj -EzARMA8GA1UdEwEB/wQFMAMCAQAwDQYJKoZIhvcNAQEFBQADggEBAEcz0RTTJ99l -HTK/zTyfV5VZEhtwqu6bwre/hD7lhI+1ji0DZYGIgCbJLKuZhj+cHn2h5nPhN7zE -M9tc4pn0TgeVS0SVFSe6TGnIFipNogvP17E+vXpDZcW/xn9kPKeVCZc1hlDt1W4Z -5q+ub3aUwuMwYs7bcArtDrumCmciJ3LFyNhebPi4mntb5ooeLFLaujEmVYyrQnpo -tWKC9sMlJmLm4yAso64Sv9KLS2T9ivJBNn0ZtougozBCCTqrqgZVjha+B2yjHe9f -sRkg/uxcJf7wC5Y0BLlp1+aPwdmZD87T3a1uQ1Ij93jmHG+2T9U20MklHAePOl0q -yTqdSPnSH1c= +MIIDazCCAlOgAwIBAgIUYYC8EshUsBUeU6IG2Fyr1Nr7KG0wDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCVVMxDTALBgNVBAgMBFV0YWgxEDAOBgNVBAcMB01pZHZh +bGUxFTATBgNVBAoMDFNwYWNlIE1vbmtleTAeFw0yMTA4MTQxODIzNDFaFw0zMTA2 +MjMxODIzNDFaMEUxCzAJBgNVBAYTAlVTMQ0wCwYDVQQIDARVdGFoMRAwDgYDVQQH +DAdNaWR2YWxlMRUwEwYDVQQKDAxTcGFjZSBNb25rZXkwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQDdf3icNvFsrlrnNLi8SocscqlSbFq+pEvmhcSoqgDL +qebnqu8Ld73HJJ74MGXEgRX8xZT5FinOML31CR6t9E/j3dqV6p+GfdlFLe3IqtC0 +/bPVnCDBirBygBI4uCrMq+1VhAxPWclrDo7l9QRYbsExH9lfn+RyvxeNMZiOASas +vVZNncY8E9usBGRdH17EfDL/TPwXqWOLyxSN5o54GTztjjy9w9CGQP7jcCueKYyQ +JQCtEmnwc6P/q6/EPv5R6drBkX6loAPtmCUAkHqxkWOJrRq/v7PwzRYhfY+ZpVHG +c7WEkDnLzRiUypr1C9oxvLKS10etZEIwEdKyOkSg2fdPAgMBAAGjUzBRMB0GA1Ud +DgQWBBSj8Z6d2TqacRP4allwQM1FYgltPzAfBgNVHSMEGDAWgBSj8Z6d2TqacRP4 +allwQM1FYgltPzAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQA2 +KJLoFWorZz+tb/HdDJTTDxy5/XhOnx+2OIALFsLJnulo8fHbJnPKspe2V08EcFZ0 +hUrvKsaXpm8VXX21yOFg5yMcrG6A3voQWIjvTCNwfywnpnsxrWwhuRqioUmR4WSW +NoFuwg+lt6bLDavM4Izl86Nb/LoAzKc6g6nKGHKJLuJma6RPJnmjfC4Os1GWf7rf +kQOP/XdA0t+JW1+ABBdOd5kOtowAvQLKzLYi6xTrvEDSjDtiKS42dVydBpj3Uaih +tCzcieQbb6KqUyxxzgTelXq2IxJUyU74Jv96BZ8cA7Qvwv1jwsfxYv7VHLuFAmtW +KCDFmLjMtdrKX+q5zJe7 -----END CERTIFICATE----- `) keyBytes = []byte(`-----BEGIN RSA PRIVATE KEY----- @@ -193,18 +190,16 @@ func SimpleConnTest(t testing.TB, constructor func( } buf := bytes.NewBuffer(make([]byte, 0, len(data))) - _, err = io.CopyN(buf, server, int64(len(data))) + _, err = io.Copy(buf, server) if err != nil { t.Fatal(err) } - if string(buf.Bytes()) != data { + if buf.String() != data { t.Fatal("mismatched data") } - err = server.Close() - if err != nil { - t.Fatal(err) - } + // Only one side gets a clean close because closing needs to write a terminator. + _ = server.Close() }() wg.Wait() } @@ -223,10 +218,10 @@ func close_both(closer1, closer2 io.Closer) { wg.Wait() } -func ClosingTest(t testing.TB, constructor func( +func ClosingTest(t *testing.T, constructor func( t testing.TB, conn1, conn2 net.Conn) (sslconn1, sslconn2 HandshakingConn)) { - run_test := func(close_tcp bool, server_writes bool) { + run_test := func(t *testing.T, close_tcp bool, server_writes bool) { server_conn, client_conn := NetPipe(t) defer server_conn.Close() defer client_conn.Close() @@ -246,12 +241,34 @@ func ClosingTest(t testing.TB, constructor func( } var wg sync.WaitGroup + + // If we're killing the TCP connection, make sure we handshake first + if close_tcp { + wg.Add(2) + go func() { + defer wg.Done() + err := sslconn1.Handshake() + if err != nil { + t.Error(err) + } + }() + go func() { + defer wg.Done() + err := sslconn2.Handshake() + if err != nil { + t.Error(err) + } + }() + wg.Wait() + } + wg.Add(2) go func() { defer wg.Done() _, err := sslconn1.Write([]byte("hello")) if err != nil { - t.Fatal(err) + t.Error(err) + return } if close_tcp { err = conn1.Close() @@ -259,28 +276,37 @@ func ClosingTest(t testing.TB, constructor func( err = sslconn1.Close() } if err != nil { - t.Fatal(err) + t.Error(err) } }() go func() { defer wg.Done() - data, err := ioutil.ReadAll(sslconn2) - if err != nil { - t.Fatal(err) - } + data, err := io.ReadAll(sslconn2) if !bytes.Equal(data, []byte("hello")) { - t.Fatal("bytes don't match") + t.Error("bytes don't match") + } + if !close_tcp && err != nil { + t.Error(err) + return } }() wg.Wait() } - run_test(true, false) - run_test(false, false) - run_test(true, true) - run_test(false, true) + t.Run("close TCP, server reads", func(t *testing.T) { + run_test(t, true, false) + }) + t.Run("close SSL, server reads", func(t *testing.T) { + run_test(t, false, false) + }) + t.Run("close TCP, server writes", func(t *testing.T) { + run_test(t, true, true) + }) + t.Run("close SSL, server writes", func(t *testing.T) { + run_test(t, false, true) + }) } func ThroughputBenchmark(b *testing.B, constructor func( @@ -304,21 +330,21 @@ func ThroughputBenchmark(b *testing.B, constructor func( wg.Add(2) go func() { defer wg.Done() - _, err = io.Copy(client, bytes.NewReader([]byte(data))) - if err != nil { - b.Fatal(err) + if _, err = io.Copy(client, bytes.NewReader(data)); err != nil { + b.Error(err) + return } }() go func() { defer wg.Done() buf := &bytes.Buffer{} - _, err = io.CopyN(buf, server, int64(len(data))) - if err != nil { - b.Fatal(err) + if _, err = io.CopyN(buf, server, int64(len(data))); err != nil { + b.Error(err) + return } if !bytes.Equal(buf.Bytes(), data) { - b.Fatal("mismatched data") + b.Error("mismatched data") } }() wg.Wait() @@ -551,27 +577,27 @@ func LotsOfConns(t *testing.T, payload_size int64, loops, clients int, for { conn, err := ssl_listener.Accept() if err != nil { - t.Fatalf("failed accept: %s", err) + t.Errorf("failed accept: %s", err) continue } go func() { defer func() { err = conn.Close() if err != nil { - t.Fatalf("failed closing: %s", err) + t.Errorf("failed closing: %s", err) } }() for i := 0; i < loops; i++ { - _, err := io.Copy(ioutil.Discard, + _, err := io.Copy(io.Discard, io.LimitReader(conn, payload_size)) if err != nil { - t.Fatalf("failed reading: %s", err) + t.Errorf("failed reading: %s", err) return } _, err = io.Copy(conn, io.LimitReader(rand.Reader, payload_size)) if err != nil { - t.Fatalf("failed writing: %s", err) + t.Errorf("failed writing: %s", err) return } } @@ -581,35 +607,37 @@ func LotsOfConns(t *testing.T, payload_size int64, loops, clients int, }() var wg sync.WaitGroup for i := 0; i < clients; i++ { - tcp_client, err := net.Dial(tcp_listener.Addr().Network(), + tcpClient, err := net.Dial(tcp_listener.Addr().Network(), tcp_listener.Addr().String()) if err != nil { - t.Fatal(err) + t.Error(err) + return } - ssl_client, err := newClient(tcp_client) + ssl_client, err := newClient(tcpClient) if err != nil { - t.Fatal(err) + t.Error(err) + return } wg.Add(1) go func(i int) { + defer wg.Done() defer func() { err = ssl_client.Close() if err != nil { - t.Fatalf("failed closing: %s", err) + t.Errorf("failed closing: %s", err) } - wg.Done() }() for i := 0; i < loops; i++ { _, err := io.Copy(ssl_client, io.LimitReader(rand.Reader, payload_size)) if err != nil { - t.Fatalf("failed writing: %s", err) + t.Errorf("failed writing: %s", err) return } - _, err = io.Copy(ioutil.Discard, + _, err = io.Copy(io.Discard, io.LimitReader(ssl_client, payload_size)) if err != nil { - t.Fatalf("failed reading: %s", err) + t.Errorf("failed reading: %s", err) return } } @@ -645,20 +673,17 @@ func TestOpenSSLLotsOfConns(t *testing.T) { if err != nil { t.Fatal(err) } - err = ctx.UsePrivateKey(key) - if err != nil { + if err = ctx.UsePrivateKey(key); err != nil { t.Fatal(err) } cert, err := LoadCertificateFromPEM(certBytes) if err != nil { t.Fatal(err) } - err = ctx.UseCertificate(cert) - if err != nil { + if err = ctx.UseCertificate(cert); err != nil { t.Fatal(err) } - err = ctx.SetCipherList("AES128-SHA") - if err != nil { + if err = ctx.SetCipherList("AES128-SHA"); err != nil { t.Fatal(err) } LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, diff --git a/tickets.go b/tickets.go index a064d385..2ee8ed9b 100644 --- a/tickets.go +++ b/tickets.go @@ -20,6 +20,8 @@ import "C" import ( "os" "unsafe" + + "github.com/mattn/go-pointer" ) const ( @@ -127,7 +129,7 @@ func go_ticket_key_cb_thunk(p unsafe.Pointer, s *C.SSL, key_name *C.uchar, } }() - ctx := (*Ctx)(p) + ctx := pointer.Restore(p).(*Ctx) store := ctx.ticket_store if store == nil { // TODO(jeff): should this be an error condition? it doesn't make sense diff --git a/utils/future.go b/utils/future.go index fa1bbbfb..df2d8312 100644 --- a/utils/future.go +++ b/utils/future.go @@ -45,35 +45,35 @@ func NewFuture() *Future { } // Get blocks until the Future has a value set. -func (self *Future) Get() (interface{}, error) { - self.mutex.Lock() - defer self.mutex.Unlock() +func (f *Future) Get() (interface{}, error) { + f.mutex.Lock() + defer f.mutex.Unlock() for { - if self.received { - return self.val, self.err + if f.received { + return f.val, f.err } - self.cond.Wait() + f.cond.Wait() } } // Fired returns whether or not a value has been set. If Fired is true, Get // won't block. -func (self *Future) Fired() bool { - self.mutex.Lock() - defer self.mutex.Unlock() - return self.received +func (f *Future) Fired() bool { + f.mutex.Lock() + defer f.mutex.Unlock() + return f.received } // Set provides the value to present and future Get calls. If Set has already // been called, this is a no-op. -func (self *Future) Set(val interface{}, err error) { - self.mutex.Lock() - defer self.mutex.Unlock() - if self.received { +func (f *Future) Set(val interface{}, err error) { + f.mutex.Lock() + defer f.mutex.Unlock() + if f.received { return } - self.received = true - self.val = val - self.err = err - self.cond.Broadcast() + f.received = true + f.val = val + f.err = err + f.cond.Broadcast() } diff --git a/version.json b/version.json new file mode 100644 index 00000000..557859c5 --- /dev/null +++ b/version.json @@ -0,0 +1,3 @@ +{ + "version": "v0.1.0" +}