diff --git a/internal/openfga/interfaces.go b/internal/openfga/interfaces.go index b5f7bfd5d..3e562a7d2 100644 --- a/internal/openfga/interfaces.go +++ b/internal/openfga/interfaces.go @@ -39,3 +39,7 @@ type OpenFGAClientInterface interface { DeleteTuples(context.Context, ...Tuple) error Check(context.Context, string, string, string, ...Tuple) (bool, error) } + +type ListPermissionsFiltersInterface interface { + WithFilter() any +} diff --git a/internal/openfga/stores.go b/internal/openfga/stores.go index aad2c6db0..6d8652612 100644 --- a/internal/openfga/stores.go +++ b/internal/openfga/stores.go @@ -242,7 +242,7 @@ func (s *OpenFGAStore) ListPermissions(ctx context.Context, ID string, continuat for _, t := range s.permissionTypes() { s.wpool.Submit( - s.listPermissionsFunc(ctx, ID, t, continuationTokens[t]), + s.listPermissionsFunc(ctx, ID, "", t, continuationTokens[t]), results, &wg, ) @@ -282,11 +282,92 @@ func (s *OpenFGAStore) ListPermissions(ctx context.Context, ID string, continuat return permissions, tMap, fmt.Errorf(eMsg) } -func (s *OpenFGAStore) listPermissionsFunc(ctx context.Context, ID, ofgaType, cToken string) func() any { +// ListPermissionsWithFilters returns all the permissions associated to a specific entity +func (s *OpenFGAStore) ListPermissionsWithFilters(ctx context.Context, ID string, opts ...ListPermissionsFiltersInterface) ([]Permission, map[string]string, error) { + ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.ListPermissionsWithFilters") + defer span.End() + + // keep it a buffered channel, if set to unbuffered we would need a goroutine + // to consume from it before pushing to it + // https://go.dev/ref/spec#Send_statements + // A send on an unbuffered channel can proceed if a receiver is ready. + // A send on a buffered channel can proceed if there is room in the buffer + results := make(chan *pool.Result[any], len(s.permissionTypes())) + + ff := new(listPermissionsOpts) + + if len(opts) != 0 { + ff = s.parseFilters(opts...) + } + + types := s.permissionTypes() + tokenMap := make(map[string]string) + + if tm := ff.TokenMap; tm != nil { + tokenMap = tm + } + + if tf := ff.TypesFilter; len(tf) > 0 { + types = tf + } + + wg := sync.WaitGroup{} + wg.Add(len(types)) + + for _, t := range types { + token, ok := tokenMap[t] + + if !ok { + token = "" + } + + s.wpool.Submit( + s.listPermissionsFunc(ctx, ID, ff.RelationFilter, t, token), + results, + &wg, + ) + } + + // wait for tasks to finish + wg.Wait() + + // close result channel + close(results) + + permissions := make([]Permission, 0) + tMap := make(map[string]string) + errors := make([]error, 0) + + for r := range results { + v := r.Value.(listPermissionsResult) + permissions = append(permissions, v.permissions...) + tMap[v.ofgaType] = v.token + + if v.err != nil { + errors = append(errors, v.err) + } + } + + if len(errors) == 0 { + return permissions, tMap, nil + } + + eMsg := "" + + for n, e := range errors { + s.logger.Errorf(e.Error()) + eMsg = fmt.Sprintf("%s%v - %s\n", eMsg, n, e.Error()) + } + + return permissions, tMap, fmt.Errorf(eMsg) +} + +func (s *OpenFGAStore) listPermissionsFunc(ctx context.Context, ID, relation, ofgaType, cToken string) func() any { return func() any { p, token, err := s.listPermissionsByType( ctx, ID, + relation, ofgaType, cToken, ) @@ -300,11 +381,11 @@ func (s *OpenFGAStore) listPermissionsFunc(ctx context.Context, ID, ofgaType, cT } } -func (s *OpenFGAStore) listPermissionsByType(ctx context.Context, ID, pType, continuationToken string) ([]Permission, string, error) { +func (s *OpenFGAStore) listPermissionsByType(ctx context.Context, ID, relation, pType, continuationToken string) ([]Permission, string, error) { ctx, span := s.tracer.Start(ctx, "openfga.OpenFGAStore.listPermissionsByType") defer span.End() - r, err := s.ofga.ReadTuples(ctx, ID, "", fmt.Sprintf("%s:", pType), continuationToken) + r, err := s.ofga.ReadTuples(ctx, ID, relation, fmt.Sprintf("%s:", pType), continuationToken) if err != nil { s.logger.Error(err.Error()) @@ -325,6 +406,52 @@ func (s *OpenFGAStore) listPermissionsByType(ctx context.Context, ID, pType, con return permissions, r.GetContinuationToken(), nil } +func (s *OpenFGAStore) parseFilters(filters ...ListPermissionsFiltersInterface) *listPermissionsOpts { + opts := new(listPermissionsOpts) + opts.TokenMap = make(map[string]string) + opts.TypesFilter = make([]string, 0) + + // this will keep only the latest filter passed in, if 2 type filters are passed, last one is kept + for _, filter := range filters { + switch f := filter.(type) { + case *TypesFilter: + if f == nil { + continue + } + + if v, ok := f.WithFilter().([]string); ok { + opts.TypesFilter = v + } else { + s.logger.Errorf("wrong types filter, casting failed: %v", f) + } + case *RelationFilter: + if f == nil { + continue + } + + if v, ok := f.WithFilter().(string); ok { + opts.RelationFilter = v + } else { + s.logger.Errorf("wrong relation filter, casting failed: %s", f) + } + case *TokenMapFilter: + if f == nil { + continue + } + + if v, ok := f.WithFilter().(map[string]string); ok { + opts.TokenMap = v + } else { + s.logger.Errorf("wrong token map, casting failed: %v", f) + } + default: + continue + } + } + + return opts +} + func (s *OpenFGAStore) permissionTypes() []string { return []string{"group", "role", "identity", "scheme", "provider", "client"} } diff --git a/internal/openfga/stores_test.go b/internal/openfga/stores_test.go index a5e0b2d90..795e8c20c 100644 --- a/internal/openfga/stores_test.go +++ b/internal/openfga/stores_test.go @@ -875,12 +875,12 @@ func TestStoreListPermissions(t *testing.T) { } expPermissions := []Permission{ - Permission{Relation: "can_edit", Object: "role:test"}, - Permission{Relation: "can_edit", Object: "group:test"}, - Permission{Relation: "can_edit", Object: "identity:test"}, - Permission{Relation: "can_edit", Object: "scheme:test"}, - Permission{Relation: "can_edit", Object: "provider:test"}, - Permission{Relation: "can_edit", Object: "client:test"}, + {Relation: "can_edit", Object: "role:test"}, + {Relation: "can_edit", Object: "group:test"}, + {Relation: "can_edit", Object: "identity:test"}, + {Relation: "can_edit", Object: "scheme:test"}, + {Relation: "can_edit", Object: "provider:test"}, + {Relation: "can_edit", Object: "client:test"}, } calls := []*gomock.Call{} @@ -961,3 +961,174 @@ func TestStoreListPermissions(t *testing.T) { }) } } + +func TestStoreListPermissionsWithPermissions(t *testing.T) { + type input struct { + ID string + relationFilter *RelationFilter + typesFilter *TypesFilter + tokenMapFilter *TokenMapFilter + } + + tests := []struct { + name string + input input + expected error + }{ + { + name: "error", + input: input{ + ID: "role:administrator#assignee", + }, + expected: fmt.Errorf("error"), + }, + { + name: "role found", + input: input{ + ID: "role:administrator#assignee", + relationFilter: NewRelationFilter("can_edit"), + tokenMapFilter: NewTokenMapFilter( + map[string]string{"role": "test"}, + ), + }, + expected: nil, + }, + { + name: "group found", + input: input{ + ID: "group:administrator#member", + typesFilter: NewTypesFilter("identity", "client"), + tokenMapFilter: NewTokenMapFilter( + map[string]string{"role": "test"}, + ), + }, + expected: nil, + }, + { + name: "user found", + input: input{ + ID: "use:joe", + tokenMapFilter: NewTokenMapFilter( + map[string]string{"role": "test"}, + ), + }, + expected: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLoggerInterface(ctrl) + mockTracer := NewMockTracer(ctrl) + mockMonitor := monitoring.NewMockMonitorInterface(ctrl) + mockOpenFGA := NewMockOpenFGAClientInterface(ctrl) + mockWorkerPool := NewMockWorkerPoolInterface(ctrl) + + store := NewOpenFGAStore(mockOpenFGA, mockWorkerPool, mockTracer, mockMonitor, mockLogger) + + types := store.permissionTypes() + + if test.input.typesFilter != nil { + types = test.input.typesFilter.WithFilter().([]string) + } + + for i := 0; i < len(types); i++ { + setupMockSubmit(mockWorkerPool, nil) + } + + mockTracer.EXPECT().Start(gomock.Any(), gomock.Any()).AnyTimes().Return(context.TODO(), trace.SpanFromContext(context.TODO())) + mockLogger.EXPECT().Error(gomock.Any()).AnyTimes() + mockLogger.EXPECT().Errorf(gomock.Any()).AnyTimes() + + expCTokens := make(map[string]string) + expPermissions := make([]Permission, 0) + + for _, t := range types { + expPermissions = append( + expPermissions, + Permission{Relation: "can_edit", Object: t + ":test"}, + ) + expCTokens[t] = "" + } + + calls := []*gomock.Call{} + + for _, _ = range types { + calls = append( + calls, + mockOpenFGA.EXPECT().ReadTuples(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, user, relation, object, continuationToken string) (*client.ClientReadResponse, error) { + if test.expected != nil { + return nil, test.expected + } + + if user != test.input.ID { + t.Errorf("wrong user parameter expected %s got %s", test.input.ID, user) + } + + if object == "role:" && continuationToken != "test" { + tokenM, ok := test.input.tokenMapFilter.WithFilter().(map[string]string) + + if !ok { + t.Fatal("failed parsing token map") + } + + t.Errorf("missing continuation token %s", tokenM["roles"]) + } + + tuples := []openfga.Tuple{ + *openfga.NewTuple( + *openfga.NewTupleKey( + user, "can_edit", fmt.Sprintf("%stest", object), + ), + time.Now(), + ), + *openfga.NewTuple( + *openfga.NewTupleKey( + user, "assignee", "role:test", + ), + time.Now(), + ), + } + + r := new(client.ClientReadResponse) + r.SetContinuationToken("") + r.SetTuples(tuples) + + return r, nil + }, + ), + ) + } + + gomock.InAnyOrder(calls) + permissions, cTokens, err := store.ListPermissionsWithFilters(context.Background(), test.input.ID, test.input.typesFilter, test.input.tokenMapFilter, test.input.relationFilter) + + if err != nil && test.expected == nil { + t.Fatalf("expected error to be silenced and return nil got %v instead", err) + } + + sortFx := func(a, b Permission) int { + if n := strings.Compare(a.Relation, b.Relation); n != 0 { + return n + } + // If relations are equal, order by object + return cmp.Compare(a.Object, b.Object) + } + + slices.SortFunc(permissions, sortFx) + slices.SortFunc(expPermissions, sortFx) + + if err == nil && test.expected == nil && !reflect.DeepEqual(permissions, expPermissions) { + t.Fatalf("expected permissions to be %v got %v", expPermissions, permissions) + } + + if err == nil && test.expected == nil && !reflect.DeepEqual(cTokens, expCTokens) { + t.Fatalf("expected continuation tokens to be %v got %v", expCTokens, cTokens) + } + }) + } +} diff --git a/internal/openfga/types.go b/internal/openfga/types.go index 9cb1873f1..634cfcff4 100644 --- a/internal/openfga/types.go +++ b/internal/openfga/types.go @@ -37,3 +37,60 @@ func NewTuple(user, relation, object string) *Tuple { return t } + +type TokenMapFilter struct { + tokens map[string]string +} + +func (f *TokenMapFilter) WithFilter() any { + return f.tokens +} + +func NewTokenMapFilter(tokens map[string]string) *TokenMapFilter { + f := new(TokenMapFilter) + f.tokens = tokens + + return f +} + +type TypesFilter struct { + resourceTypes []string +} + +func (f *TypesFilter) WithFilter() any { + return f.resourceTypes +} + +func NewTypesFilter(resourceTypes ...string) *TypesFilter { + f := new(TypesFilter) + + f.resourceTypes = make([]string, 0) + + for _, r := range resourceTypes { + f.resourceTypes = append(f.resourceTypes, r) + } + + return f +} + +type RelationFilter struct { + relation string +} + +func (f *RelationFilter) WithFilter() any { + return f.relation +} + +func NewRelationFilter(relation string) *RelationFilter { + f := new(RelationFilter) + + f.relation = relation + + return f +} + +type listPermissionsOpts struct { + TokenMap map[string]string + TypesFilter []string + RelationFilter string +}