From 223aa5bb6cf4f24193ad6c6037e1b88160474f2e Mon Sep 17 00:00:00 2001 From: Noah Dietz Date: Mon, 23 Sep 2024 12:28:38 -0700 Subject: [PATCH] fix(locations): make source info access concurrent safe (#1433) * fix(locations): make source info access concurrent safe * follow mutex hat pattern * add test and run unit tests with -race --- .github/workflows/ci.yaml | 2 +- locations/locations.go | 47 +++++++++++++++++++++++++++++-------- locations/locations_test.go | 28 ++++++++++++++++++++++ 3 files changed, 66 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9a7b67f4e..399db11a4 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -12,7 +12,7 @@ jobs: - uses: actions/setup-go@v5 with: go-version: "1.20" - - run: go test -p 1 ./... + - run: go test -race ./... lint: runs-on: ubuntu-latest steps: diff --git a/locations/locations.go b/locations/locations.go index 27e5dd965..7d8fadcd8 100644 --- a/locations/locations.go +++ b/locations/locations.go @@ -23,6 +23,8 @@ package locations import ( + "sync" + "github.com/jhump/protoreflect/desc" dpb "google.golang.org/protobuf/types/descriptorpb" ) @@ -37,12 +39,25 @@ func pathLocation(d desc.Descriptor, path ...int) *dpb.SourceCodeInfo_Location { return sourceInfoRegistry.sourceInfo(d.GetFile()).findLocation(fullPath) } -type sourceInfo map[string]*dpb.SourceCodeInfo_Location +type sourceInfo struct { + // infoMu protects the info map + infoMu sync.Mutex + info map[string]*dpb.SourceCodeInfo_Location +} + +func newSourceInfo() *sourceInfo { + return &sourceInfo{ + info: map[string]*dpb.SourceCodeInfo_Location{}, + } +} // findLocation returns the Location for a given path. -func (si sourceInfo) findLocation(path []int32) *dpb.SourceCodeInfo_Location { +func (si *sourceInfo) findLocation(path []int32) *dpb.SourceCodeInfo_Location { + si.infoMu.Lock() + defer si.infoMu.Unlock() + // If the path exists in the source info registry, return that object. - if loc, ok := si[strPath(path)]; ok { + if loc, ok := si.info[strPath(path)]; ok { return loc } @@ -53,7 +68,17 @@ func (si sourceInfo) findLocation(path []int32) *dpb.SourceCodeInfo_Location { // The source map registry is a singleton that computes a source map for // any file descriptor that it is given, but then caches it to avoid computing // the source map for the same file descriptors over and over. -type sourceInfoRegistryType map[*desc.FileDescriptor]sourceInfo +type sourceInfoRegistryType struct { + // registryMu protects the registry map + registryMu sync.Mutex + registry map[*desc.FileDescriptor]*sourceInfo +} + +func newSourceInfoRegistryType() *sourceInfoRegistryType { + return &sourceInfoRegistryType{ + registry: map[*desc.FileDescriptor]*sourceInfo{}, + } +} // Each location has a path defined as an []int32, but we can not // use slices as keys, so compile them into a string. @@ -70,22 +95,24 @@ func strPath(segments []int32) (p string) { // sourceInfo compiles the source info object for a given file descriptor. // It also caches this into a registry, so subsequent calls using the same // descriptor will return the same object. -func (sir sourceInfoRegistryType) sourceInfo(fd *desc.FileDescriptor) sourceInfo { - answer, ok := sir[fd] +func (sir *sourceInfoRegistryType) sourceInfo(fd *desc.FileDescriptor) *sourceInfo { + sir.registryMu.Lock() + defer sir.registryMu.Unlock() + answer, ok := sir.registry[fd] if !ok { - answer = sourceInfo{} + answer = newSourceInfo() // This file descriptor does not yet have a source info map. // Compile one. for _, loc := range fd.AsFileDescriptorProto().GetSourceCodeInfo().GetLocation() { - answer[strPath(loc.Path)] = loc + answer.info[strPath(loc.Path)] = loc } // Now that we calculated all of this, cache it on the registry so it // does not need to be calculated again. - sir[fd] = answer + sir.registry[fd] = answer } return answer } -var sourceInfoRegistry = sourceInfoRegistryType{} +var sourceInfoRegistry = newSourceInfoRegistryType() diff --git a/locations/locations_test.go b/locations/locations_test.go index 2ada999a4..e4afe555f 100644 --- a/locations/locations_test.go +++ b/locations/locations_test.go @@ -16,6 +16,7 @@ package locations import ( "strings" + "sync" "testing" "github.com/jhump/protoreflect/desc" @@ -47,3 +48,30 @@ func parse(t *testing.T, s string) *desc.FileDescriptor { } return fds[0] } + +func TestSourceInfo_Concurrency(t *testing.T) { + fd := parse(t, ` + syntax = "proto3"; + package foo.bar; + `) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + FileSyntax(fd) + }() + + wg.Add(1) + go func() { + defer wg.Done() + FilePackage(fd) + }() + + wg.Add(1) + go func() { + defer wg.Done() + FileImport(fd, 0) + }() + wg.Wait() +}