Skip to content

Add a generic FGA retriever that can be used with LLM providers directly #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 261 additions & 25 deletions examples/authorization-for-rag/langchain-examples/poetry.lock

Large diffs are not rendered by default.

Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def create_retriever(user: str):
base_retriever,
build_query=lambda node: ClientBatchCheckItem(
user=f"user:{user}",
object=f"doc:{node.ref_doc_id}",
object=f"doc:{node.node.ref_doc_id}",
relation="viewer",
),
)
Expand Down
218 changes: 215 additions & 3 deletions examples/authorization-for-rag/llama-index-examples/poetry.lock

Large diffs are not rendered by default.

Empty file.
43 changes: 43 additions & 0 deletions packages/auth0-ai/auth0_ai/authorizers/fga/fga_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os

from typing import Optional
from openfga_sdk import OpenFgaClient, ClientConfiguration
from openfga_sdk.sync import OpenFgaClient as OpenFgaClientSync

from openfga_sdk.credentials import CredentialConfiguration, Credentials


def get_openfga_client_configuration(
fga_client_configuration: Optional[ClientConfiguration],
):
"""
Returns a ClientConfiguration object based on the user's params, or the default values read from environment
"""
return fga_client_configuration or ClientConfiguration(
api_url=os.getenv("FGA_API_URL") or "https://api.us1.fga.dev",
store_id=os.getenv("FGA_STORE_ID"),
credentials=Credentials(
method="client_credentials",
configuration=CredentialConfiguration(
api_issuer=os.getenv("FGA_API_TOKEN_ISSUER") or "auth.fga.dev",
api_audience=os.getenv("FGA_API_AUDIENCE")
or "https://api.us1.fga.dev/",
client_id=os.getenv("FGA_CLIENT_ID"),
client_secret=os.getenv("FGA_CLIENT_SECRET"),
),
),
)


def build_openfga_client_sync(fga_client_configuration: Optional[ClientConfiguration]):
"""
Returns an instance of the sync OpenFGA Client
"""
return OpenFgaClientSync(get_openfga_client_configuration(fga_client_configuration))


def build_openfga_client(fga_client_configuration: Optional[ClientConfiguration]):
"""
Returns an instance of the async OpenFGA Client
"""
return OpenFgaClient(get_openfga_client_configuration(fga_client_configuration))
135 changes: 135 additions & 0 deletions packages/auth0-ai/auth0_ai/authorizers/fga/fga_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from typing import Any, Callable, Optional
from auth0_ai.authorizers.fga.fga_client import (
build_openfga_client,
build_openfga_client_sync,
)
from openfga_sdk import ClientConfiguration
from openfga_sdk.client.models import ClientBatchCheckItem
from openfga_sdk.client.client import ClientBatchCheckRequest


class FGAFilter[T]:
_query_builder: Callable[[T], ClientBatchCheckItem]
_fga_configuration: Optional[ClientConfiguration]

def __init__(
self,
query_builder: Callable[[T], ClientBatchCheckItem],
fga_configuration: Optional[ClientConfiguration],
):
"""
Initialize the FGAFilter with the specified query builder, and FGA parameters.

Args:
query_builder (Callable[[T], ClientBatchCheckItem]): Function to convert documents into FGA queries.
fga_configuration (Optional[ClientConfiguration]): Configuration for the OpenFGA client. If not provided, defaults to environment variables.
"""
self._fga_configuration = fga_configuration
self._query_builder = query_builder

async def filter(
self,
documents: list[T],
hash_document: Optional[Callable[[T], Any]] = None,
) -> list[T]:
"""
Asynchronously filter documents using OpenFGA.

Args:
documents (List[T]): List of documents to filter.
hash_document (Optional[Callable[T], Any]]: The filter function hashes the documents during the process. In some contexts, documents are not hasheable. You can provide this function which receives the document and returns a hashable, unique representation of the document.

Returns:
List[T]: Filtered list of documents authorized by FGA.
"""
if len(documents) == 0:
return []

async with build_openfga_client(self._fga_configuration) as fga_client:
all_checks = [self._query_builder(doc) for doc in documents]
unique_checks = list(
{
(check.relation, check.object, check.user): check
for check in all_checks
}.values()
)

def default_hash_document(doc):
return doc

if not hash_document:
hash_document = default_hash_document

doc_to_obj = {
hash_document(doc): check.object
for check, doc in zip(all_checks, documents)
}

fga_response = await fga_client.batch_check(
ClientBatchCheckRequest(checks=unique_checks)
)
await fga_client.close()

permissions_map = {
result.request.object: result.allowed for result in fga_response.result
}

return [
doc
for doc in documents
if doc_to_obj[hash_document(doc)] in permissions_map
and permissions_map[doc_to_obj[hash_document(doc)]]
]

def filter_sync(
self,
documents: list[T],
hash_document: Optional[Callable[[T], Any]] = None,
) -> list[T]:
"""
Synchronously filter documents using OpenFGA.

Args:
documents (List[T]): List of documents to filter.
hash_document (Optional[Callable[T], Any]]: The filter function hashes the documents during the process. In some contexts, documents are not hasheable. You can provide this function which receives the document and returns a hashable, unique representation of the document.

Returns:
List[T]: Filtered list of documents authorized by FGA.
"""
if len(documents) == 0:
return []

with build_openfga_client_sync(self._fga_configuration) as fga_client:
all_checks = [self._query_builder(doc) for doc in documents]
unique_checks = list(
{
(check.relation, check.object, check.user): check
for check in all_checks
}.values()
)

def default_hash_document(doc):
return doc

if not hash_document:
hash_document = default_hash_document

doc_to_obj = {
hash_document(doc): check.object
for check, doc in zip(all_checks, documents)
}

fga_response = fga_client.batch_check(
ClientBatchCheckRequest(checks=unique_checks)
)

permissions_map = {
result.request.object: result.allowed for result in fga_response.result
}

return [
doc
for doc in documents
if doc_to_obj[hash_document(doc)] in permissions_map
and permissions_map[doc_to_obj[hash_document(doc)]]
]
Loading