From 8e01db5277341f4e6e9b8bffc07b5c5f86e29680 Mon Sep 17 00:00:00 2001 From: smndtrl Date: Wed, 21 Aug 2024 08:54:08 +0000 Subject: [PATCH 1/4] initial sso with oidc --- internal/api/api.go | 24 + internal/api/context.go | 47 +- internal/api/external.go | 5 + internal/api/helpers.go | 1 + internal/api/provider/generic.go | 280 +++++++++++ internal/api/samlacs.go | 2 +- internal/api/sso.go | 76 +-- internal/api/sso_oidc.go | 124 +++++ internal/api/sso_saml.go | 55 ++ internal/api/ssoadmin.go | 4 +- internal/api/ssooidcadmin.go | 476 ++++++++++++++++++ internal/conf/configuration.go | 10 + internal/models/factor.go | 5 + internal/models/sso.go | 163 +++++- internal/models/sso_test.go | 18 +- migrations/20240819081613_add_oidc_sso.up.sql | 49 ++ 16 files changed, 1260 insertions(+), 79 deletions(-) create mode 100644 internal/api/provider/generic.go create mode 100644 internal/api/sso_oidc.go create mode 100644 internal/api/sso_saml.go create mode 100644 internal/api/ssooidcadmin.go create mode 100644 migrations/20240819081613_add_oidc_sso.up.sql diff --git a/internal/api/api.go b/internal/api/api.go index 85292775f..44e3b2f6c 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -276,6 +276,17 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne }).SetBurst(30), )).Post("/acs", api.SAMLACS) }) + + r.Route("/oidc", func(r *router) { + r.Route("/callback", func(r *router) { + r.Use(api.isValidExternalHost) + r.Use(api.loadSSOOIDCFlowState) + + r.Get("/", api.ExternalProviderCallback) + r.Post("/", api.ExternalProviderCallback) + }) + }) + }) r.Route("/admin", func(r *router) { @@ -320,6 +331,19 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Put("/", api.adminSSOProvidersUpdate) r.Delete("/", api.adminSSOProvidersDelete) }) + + r.Route("/oidc", func(r *router) { + r.Get("/", api.adminOIDCSSOProvidersList) + r.Post("/", api.adminOIDCSSOProvidersCreate) + + r.Route("/{idp_id}", func(r *router) { + r.Use(api.loadOIDCSSOProvider) + + r.Get("/", api.adminOIDCSSOProvidersGet) + // r.Put("/", api.adminOIDCSSOProvidersUpdate) + r.Delete("/", api.adminOIDCSSOProvidersDelete) + }) + }) }) }) diff --git a/internal/api/context.go b/internal/api/context.go index 3047f3dd6..f7738e17c 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -5,6 +5,7 @@ import ( "net/url" jwt "github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" ) @@ -15,22 +16,23 @@ func (c contextKey) String() string { } const ( - tokenKey = contextKey("jwt") - inviteTokenKey = contextKey("invite_token") - signatureKey = contextKey("signature") - externalProviderTypeKey = contextKey("external_provider_type") - userKey = contextKey("user") - targetUserKey = contextKey("target_user") - factorKey = contextKey("factor") - sessionKey = contextKey("session") - externalReferrerKey = contextKey("external_referrer") - functionHooksKey = contextKey("function_hooks") - adminUserKey = contextKey("admin_user") - oauthTokenKey = contextKey("oauth_token") // for OAuth1.0, also known as request token - oauthVerifierKey = contextKey("oauth_verifier") - ssoProviderKey = contextKey("sso_provider") - externalHostKey = contextKey("external_host") - flowStateKey = contextKey("flow_state_id") + tokenKey = contextKey("jwt") + inviteTokenKey = contextKey("invite_token") + signatureKey = contextKey("signature") + externalProviderTypeKey = contextKey("external_provider_type") + userKey = contextKey("user") + targetUserKey = contextKey("target_user") + factorKey = contextKey("factor") + sessionKey = contextKey("session") + externalReferrerKey = contextKey("external_referrer") + functionHooksKey = contextKey("function_hooks") + adminUserKey = contextKey("admin_user") + oauthTokenKey = contextKey("oauth_token") // for OAuth1.0, also known as request token + oauthVerifierKey = contextKey("oauth_verifier") + ssoProviderKey = contextKey("sso_provider") + externalHostKey = contextKey("external_host") + flowStateKey = contextKey("flow_state_id") + genericProviderConfigKey = contextKey("generic_provider_config") ) // withToken adds the JWT token to the context. @@ -241,3 +243,16 @@ func getExternalHost(ctx context.Context) *url.URL { } return obj.(*url.URL) } + +func withGenericProviderConfig(ctx context.Context, token *conf.GenericOAuthProviderConfiguration) context.Context { + return context.WithValue(ctx, genericProviderConfigKey, token) +} + +func getGenericProviderConfig(ctx context.Context) *conf.GenericOAuthProviderConfiguration { + obj := ctx.Value(genericProviderConfigKey) + if obj == nil { + return nil + } + + return obj.(*conf.GenericOAuthProviderConfiguration) +} diff --git a/internal/api/external.go b/internal/api/external.go index ef6032d9a..27c798711 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log" "net/http" "net/url" "strconv" @@ -250,6 +251,7 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re // This means that the callback is using PKCE // Set the flowState.AuthCode to the query param here rurl, err = a.prepPKCERedirectURL(rurl, flowState.AuthCode) + log.Println("rurl", rurl) if err != nil { return err } @@ -568,6 +570,9 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide return provider.NewWorkOSProvider(config.External.WorkOS) case "zoom": return provider.NewZoomProvider(config.External.Zoom) + case "sso/oidc": + config := getGenericProviderConfig(ctx) + return provider.NewGenericProvider(*config, scopes) default: return nil, fmt.Errorf("Provider %s could not be found", name) } diff --git a/internal/api/helpers.go b/internal/api/helpers.go index 692139252..96f3b22d2 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -65,6 +65,7 @@ func getBodyBytes(req *http.Request) ([]byte, error) { type RequestParams interface { AdminUserParams | CreateSSOProviderParams | + CreateOIDCSSOProviderParams | EnrollFactorParams | GenerateLinkParams | IdTokenGrantParams | diff --git a/internal/api/provider/generic.go b/internal/api/provider/generic.go new file mode 100644 index 000000000..25c136340 --- /dev/null +++ b/internal/api/provider/generic.go @@ -0,0 +1,280 @@ +package provider + +import ( + "context" + "encoding/json" + "fmt" + "io" + "math" + "net/http" + "strconv" + "strings" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" + "golang.org/x/oauth2" +) + +type genericProvider struct { + *oauth2.Config + Issuer string + UserInfoURL string + UserDataMapping map[string]string +} + +func (p genericProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return p.Exchange(context.Background(), code) +} + +func (p genericProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u map[string]interface{} + + // Perform http request manually, because we need to vary it based on the provider config + req, err := http.NewRequest("GET", p.UserInfoURL, nil) + + if err != nil { + return nil, err + } + + // set headers + req.Header.Set("Client-Id", p.ClientID) + req.Header.Set("Authorization", "Bearer "+tok.AccessToken) + + client := &http.Client{Timeout: defaultTimeout} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer utilities.SafeClose(resp.Body) + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("a %v error occurred with retrieving user from OAuth2 provider via %s", resp.StatusCode, p.UserInfoURL) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + err = json.Unmarshal(body, &u) + if err != nil { + return nil, err + } + + // Read user data as specified in the JSON mapping + mapping := p.UserDataMapping + + email, err := getStringFieldByPath(u, mapping["Email"], "") + if err != nil { + return nil, err + } + + emailVerified, err := getBooleanFieldByPath(u, mapping["EmailVerified"], email != "") + if err != nil { + return nil, err + } + + emailPrimary, err := getBooleanFieldByPath(u, mapping["EmailPrimary"], email != "") + if err != nil { + return nil, err + } + + issuer, err := getStringFieldByPath(u, mapping["Issuer"], p.Issuer) + if err != nil { + return nil, err + } + + subject, err := getStringFieldByPath(u, mapping["Subject"], "") + if err != nil { + return nil, err + } + + name, err := getStringFieldByPath(u, mapping["Name"], "") + if err != nil { + return nil, err + } + + familyName, err := getStringFieldByPath(u, mapping["FamilyName"], "") + if err != nil { + return nil, err + } + + givenName, err := getStringFieldByPath(u, mapping["GivenName"], "") + if err != nil { + return nil, err + } + + middleName, err := getStringFieldByPath(u, mapping["MiddleName"], "") + if err != nil { + return nil, err + } + + nickName, err := getStringFieldByPath(u, mapping["NickName"], "") + if err != nil { + return nil, err + } + + preferredUsername, err := getStringFieldByPath(u, mapping["PreferredUsername"], "") + if err != nil { + return nil, err + } + + profile, err := getStringFieldByPath(u, mapping["Profile"], "") + if err != nil { + return nil, err + } + + picture, err := getStringFieldByPath(u, mapping["Picture"], "") + if err != nil { + return nil, err + } + + website, err := getStringFieldByPath(u, mapping["Website"], "") + if err != nil { + return nil, err + } + + gender, err := getStringFieldByPath(u, mapping["Gender"], "") + if err != nil { + return nil, err + } + + birthdate, err := getStringFieldByPath(u, mapping["Birthdate"], "") + if err != nil { + return nil, err + } + + zoneInfo, err := getStringFieldByPath(u, mapping["ZoneInfo"], "") + if err != nil { + return nil, err + } + + locale, err := getStringFieldByPath(u, mapping["Locale"], "") + if err != nil { + return nil, err + } + + updatedAt, err := getStringFieldByPath(u, mapping["UpdatedAt"], "") + if err != nil { + return nil, err + } + + phone, err := getStringFieldByPath(u, mapping["Phone"], "") + if err != nil { + return nil, err + } + + phoneVerified, err := getBooleanFieldByPath(u, mapping["PhoneVerified"], phone != "") + if err != nil { + return nil, err + } + + data := &UserProvidedData{ + Emails: []Email{ + { + Email: email, + Verified: emailVerified, + Primary: emailPrimary, + }, + }, + Metadata: &Claims{ + Issuer: issuer, + Subject: subject, + Name: name, + FamilyName: familyName, + GivenName: givenName, + MiddleName: middleName, + NickName: nickName, + PreferredUsername: preferredUsername, + Profile: profile, + Picture: picture, + Website: website, + Gender: gender, + Birthdate: birthdate, + ZoneInfo: zoneInfo, + Locale: locale, + UpdatedAt: updatedAt, + Email: email, + EmailVerified: emailVerified, + Phone: phone, + PhoneVerified: phoneVerified, + }, + } + + return data, nil +} + +func getFieldByPath(obj map[string]interface{}, path string, fallback interface{}) (interface{}, error) { + value := obj + + pathParts := strings.Split(path, ".") + for index, field := range pathParts { + fieldValue, ok := value[field] + if !ok { + return fallback, nil + } + + if index == len(pathParts)-1 { + return fieldValue, nil + } + + value = fieldValue.(map[string]interface{}) + } + + return nil, nil +} + +func getStringFieldByPath(obj map[string]interface{}, path string, fallback string) (string, error) { + value, err := getFieldByPath(obj, path, fallback) + if err != nil { + return "", err + } + if result, ok := value.(string); ok { + return result, nil + } else if intValue, ok := value.(int); ok { + return strconv.Itoa(intValue), nil + } else if floatValue, ok := value.(float64); ok { + return strconv.Itoa(int(math.Round(floatValue))), nil + } else if value == nil { + return "", nil + } else { + return "", fmt.Errorf("unable to read field as string: %q %q", path, value) + } +} + +func getBooleanFieldByPath(obj map[string]interface{}, path string, fallback bool) (bool, error) { + value, err := getFieldByPath(obj, path, fallback) + if err != nil { + return false, err + } + if result, ok := value.(bool); ok { + return result, nil + } else { + return false, fmt.Errorf("unable to read field as boolean: %q", path) + } +} + +// NewGenericProvider creates an OAuth provider according to the config specified by the user +func NewGenericProvider(ext conf.GenericOAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + oauthScopes := strings.Split(scopes, ",") + + return &genericProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: ext.AuthURL, + TokenURL: ext.TokenURL, + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + Issuer: ext.Issuer, + UserInfoURL: ext.UserInfoURL, + UserDataMapping: ext.UserDataMapping, + }, nil +} diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index 0916a7235..7c4932fe4 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -157,7 +157,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { logentry.Warn("SAML Metadata for identity provider will expire soon! Update its metadata_xml!") } - } else if *ssoProvider.SAMLProvider.MetadataURL != "" && IsSAMLMetadataStale(idpMetadata, ssoProvider.SAMLProvider) { + } else if *ssoProvider.SAMLProvider.MetadataURL != "" && IsSAMLMetadataStale(idpMetadata, *ssoProvider.SAMLProvider) { rawMetadata, err := fetchSAMLMetadata(ctx, *ssoProvider.SAMLProvider.MetadataURL) if err != nil { // Fail silently but raise warning and continue with existing metadata diff --git a/internal/api/sso.go b/internal/api/sso.go index 10034075c..080d3df70 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -2,11 +2,10 @@ package api import ( "net/http" + "net/url" - "github.com/crewjam/saml" "github.com/gofrs/uuid" "github.com/supabase/auth/internal/models" - "github.com/supabase/auth/internal/storage" ) type SingleSignOnParams struct { @@ -57,16 +56,6 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil { return err } - flowType := getFlowFromChallenge(params.CodeChallenge) - var flowStateID *uuid.UUID - flowStateID = nil - if isPKCEFlow(flowType) { - flowState, err := generateFlowState(db, models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil) - if err != nil { - return err - } - flowStateID = &flowState.ID - } var ssoProvider *models.SSOProvider @@ -86,48 +75,37 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { } } - entityDescriptor, err := ssoProvider.SAMLProvider.EntityDescriptor() - if err != nil { - return internalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err) - } - - serviceProvider := a.getSAMLServiceProvider(entityDescriptor, false /* <- idpInitiated */) - - authnRequest, err := serviceProvider.MakeAuthenticationRequest( - serviceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding), - saml.HTTPRedirectBinding, - saml.HTTPPostBinding, - ) - if err != nil { - return internalServerError("Error creating SAML Authentication Request").WithInternalError(err) - } - - // Some IdPs do not support the use of the `persistent` NameID format, - // and require a different format to be sent to work. - if ssoProvider.SAMLProvider.NameIDFormat != nil { - authnRequest.NameIDPolicy.Format = ssoProvider.SAMLProvider.NameIDFormat - } - - relayState := models.SAMLRelayState{ - SSOProviderID: ssoProvider.ID, - RequestID: authnRequest.ID, - RedirectTo: params.RedirectTo, - FlowStateID: flowStateID, + var authMethod models.AuthenticationMethod + var providerType string + // providerType, authMethod := "", models.AuthenticationMethod + if ssoProvider.OIDCProvider == nil || ssoProvider.OIDCProvider.ClientId == "" { + providerType, authMethod = models.SSOSAML.String(), models.SSOSAML + } else { + providerType, authMethod = models.SSOOIDC.String(), models.SSOOIDC } - if err := db.Transaction(func(tx *storage.Connection) error { - if terr := tx.Create(&relayState); terr != nil { - return internalServerError("Error creating SAML relay state from sign up").WithInternalError(err) + flowType := getFlowFromChallenge(params.CodeChallenge) + var flowStateID *uuid.UUID + flowStateID = nil + if isPKCEFlow(flowType) { + flowState, err := generateFlowState(db, providerType, authMethod, codeChallengeMethod, codeChallenge, nil) + if err != nil { + return err } - - return nil - }); err != nil { - return err + flowStateID = &flowState.ID } - ssoRedirectURL, err := authnRequest.Redirect(relayState.ID.String(), serviceProvider) - if err != nil { - return internalServerError("Error creating SAML authentication request redirect URL").WithInternalError(err) + var ssoRedirectURL *url.URL + if authMethod == models.SSOSAML { + ssoRedirectURL, err = GenerateRedirectWithSAML(a, db, ssoProvider, flowStateID, params) + if err != nil { + return internalServerError("Error creating SAML authentication request redirect URL").WithInternalError(err) + } + } else if authMethod == models.SSOOIDC { + ssoRedirectURL, err = GenerateRedirectWithOIDC(a, db, ssoProvider, flowStateID, params) + if err != nil { + return internalServerError("Error creating OIDC authentication request redirect URL").WithInternalError(err) + } } skipHTTPRedirect := false diff --git a/internal/api/sso_oidc.go b/internal/api/sso_oidc.go new file mode 100644 index 000000000..dcd77f9ad --- /dev/null +++ b/internal/api/sso_oidc.go @@ -0,0 +1,124 @@ +package api + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "net/http" + "net/url" + + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// GenerateRandomState generates a random state string for OAuth2 +func GenerateRandomState(length int) (string, error) { + // Create a byte slice to hold the random bytes + bytes := make([]byte, length) + + // Read random bytes into the slice + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + // Encode the random bytes into a URL-safe base64 string + return base64.URLEncoding.EncodeToString(bytes), nil +} + +func GenerateRedirectWithOIDC(a *API, db *storage.Connection, ssoProvider *models.SSOProvider, flowStateID *uuid.UUID, params *SingleSignOnParams) (*url.URL, error) { + oidcProviderConfig, err := ssoProvider.OIDCProvider.GenericProviderConfig() + if err != nil { + return &url.URL{}, internalServerError("Error creating generic OIDC provider config").WithInternalError(err) + } + + oidcProviderConfig.RedirectURI = fmt.Sprintf("%s/sso/oidc/callback", a.config.API.ExternalURL) + + provider, err := provider.NewGenericProvider(oidcProviderConfig, "openid") + if err != nil { + return &url.URL{}, internalServerError("Error creating generic OIDC provider").WithInternalError(err) + } + + state, err := GenerateRandomState(32) + if err != nil { + return &url.URL{}, internalServerError("Error creating state").WithInternalError(err) + } + + relayState := models.OIDCFlowState{ + SSOProviderID: ssoProvider.ID, + State: state, + RedirectTo: params.RedirectTo, + FlowStateID: flowStateID, + } + + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(&relayState); terr != nil { + return internalServerError("Error creating SAML relay state from sign up").WithInternalError(err) + } + + return nil + }); err != nil { + return &url.URL{}, err + } + + link := provider.AuthCodeURL(state) + + parsedUrl, err := url.Parse(link) + if err != nil { + return &url.URL{}, internalServerError("Error creating generic auth URL").WithInternalError(err) + } + + return parsedUrl, nil +} + +// loadFlowState parses the `state` query parameter as a JWS payload, +// extracting the provider requested +func (a *API) loadSSOOIDCFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) { + var state string + if r.Method == http.MethodPost { + state = r.FormValue("state") + } else { + state = r.URL.Query().Get("state") + } + + if state == "" { + return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing") + } + + ctx := r.Context() + oauthToken := r.URL.Query().Get("oauth_token") + if oauthToken != "" { + ctx = withRequestToken(ctx, oauthToken) + } + oauthVerifier := r.URL.Query().Get("oauth_verifier") + if oauthVerifier != "" { + ctx = withOAuthVerifier(ctx, oauthVerifier) + } + return a.loadSSOIDCState(ctx, state) +} + +func (a *API) loadSSOIDCState(ctx context.Context, state string) (context.Context, error) { + db := a.db.WithContext(ctx) + + flowState, err := models.FindOIDCFlowStateByID(db, state) + if err != nil { + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) + } + + ctx = withFlowStateID(ctx, flowState.FlowStateID.String()) + + ssoProvider, err := models.FindSSOProviderByID(db, flowState.SSOProviderID) + if err != nil { + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback provider not found").WithInternalError(err) + } + config, err := ssoProvider.OIDCProvider.GenericProviderConfig() + + config.RedirectURI = fmt.Sprintf("%s/sso/oidc/callback", a.config.API.ExternalURL) + + ctx = withGenericProviderConfig(ctx, &config) + ctx = withExternalProviderType(ctx, "sso/oidc") + + return ctx, err +} diff --git a/internal/api/sso_saml.go b/internal/api/sso_saml.go new file mode 100644 index 000000000..b27b6e69c --- /dev/null +++ b/internal/api/sso_saml.go @@ -0,0 +1,55 @@ +package api + +import ( + "net/url" + + "github.com/crewjam/saml" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +func GenerateRedirectWithSAML(a *API, db *storage.Connection, ssoProvider *models.SSOProvider, flowStateID *uuid.UUID, params *SingleSignOnParams) (*url.URL, error) { + entityDescriptor, err := ssoProvider.SAMLProvider.EntityDescriptor() + if err != nil { + return &url.URL{}, internalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err) + } + + serviceProvider := a.getSAMLServiceProvider(entityDescriptor, false /* <- idpInitiated */) + + authnRequest, err := serviceProvider.MakeAuthenticationRequest( + serviceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding), + saml.HTTPRedirectBinding, + saml.HTTPPostBinding, + ) + if err != nil { + return &url.URL{}, internalServerError("Error creating SAML Authentication Request").WithInternalError(err) + } + + // Some IdPs do not support the use of the `persistent` NameID format, + // and require a different format to be sent to work. + if ssoProvider.SAMLProvider.NameIDFormat != nil { + authnRequest.NameIDPolicy.Format = ssoProvider.SAMLProvider.NameIDFormat + } + + relayState := models.SAMLRelayState{ + SSOProviderID: ssoProvider.ID, + RequestID: authnRequest.ID, + RedirectTo: params.RedirectTo, + FlowStateID: flowStateID, + } + + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(&relayState); terr != nil { + return internalServerError("Error creating SAML relay state from sign up").WithInternalError(err) + } + + return nil + }); err != nil { + return &url.URL{}, err + } + + ssoRedirectURL, err := authnRequest.Redirect(relayState.ID.String(), serviceProvider) + + return ssoRedirectURL, err +} diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go index 20fd8b9c5..d8c72a820 100644 --- a/internal/api/ssoadmin.go +++ b/internal/api/ssoadmin.go @@ -223,7 +223,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er provider := &models.SSOProvider{ // TODO handle Name, Description, Attribute Mapping - SAMLProvider: models.SAMLProvider{ + SAMLProvider: &models.SAMLProvider{ EntityID: metadata.EntityID, MetadataXML: string(rawMetadata), }, @@ -390,7 +390,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er } if updateAttributeMapping || updateSAMLProvider { - if terr := tx.Eager().Update(&provider.SAMLProvider); terr != nil { + if terr := tx.Eager().Update(provider.SAMLProvider); terr != nil { return terr } } diff --git a/internal/api/ssooidcadmin.go b/internal/api/ssooidcadmin.go new file mode 100644 index 000000000..1fbca2233 --- /dev/null +++ b/internal/api/ssooidcadmin.go @@ -0,0 +1,476 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +// loadSSOProvider looks for an idp_id parameter in the URL route and loads the SSO provider +// with that ID (or resource ID) and adds it to the context. +func (a *API) loadOIDCSSOProvider(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + + idpParam := chi.URLParam(r, "idp_id") + + idpID, err := uuid.FromString(idpParam) + if err != nil { + // idpParam is not UUIDv4 + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") + } + + // idpParam is a UUIDv4 + provider, err := models.FindSSOProviderByID(db, idpID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") + } else { + return nil, internalServerError("Database error finding SSO Identity Provider").WithInternalError(err) + } + } + + observability.LogEntrySetField(r, "sso_provider_id", provider.ID.String()) + + return withSSOProvider(r.Context(), provider), nil +} + +// adminSSOProvidersList lists all SAML SSO Identity Providers in the system. Does +// not deal with pagination at this time. +func (a *API) adminOIDCSSOProvidersList(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + providers, err := models.FindAllSAMLProviders(db) + if err != nil { + return err + } + + for i := range providers { + // remove metadata XML so that the returned JSON is not ginormous + providers[i].SAMLProvider.MetadataXML = "" + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "items": providers, + }) +} + +type CreateOIDCSSOProviderParams struct { + Type string `json:"type"` + + ClientId string `json:"client_id"` + Secret string `json:"secret"` + AuthURL string `json:"auth_url"` + TokenURL string `json:"token_url"` + UserinfoURL string `json:"userinfo_url"` + // MetadataURL string `json:"metadata_url"` + // MetadataXML string `json:"metadata_xml"` + + DiscoveryURL string `json:"discover_url"` + + Domains []string `json:"domains"` + AttributeMapping models.UserDataMapping `json:"attribute_mapping"` + // NameIDFormat string `json:"name_id_format"` +} + +func (p *CreateOIDCSSOProviderParams) validate(forUpdate bool) error { + if !forUpdate && p.Type != "oidc" { + return badRequestError(ErrorCodeValidationFailed, "Only 'oidc' supported for SSO provider type") + } + // } else if p.MetadataURL != "" && p.MetadataXML != "" { + // return badRequestError(ErrorCodeValidationFailed, "Only one of metadata_xml or metadata_url needs to be set") + // } else if !forUpdate && p.MetadataURL == "" && p.MetadataXML == "" { + // return badRequestError(ErrorCodeValidationFailed, "Either metadata_xml or metadata_url must be set") + // } else if p.MetadataURL != "" { + // metadataURL, err := url.ParseRequestURI(p.MetadataURL) + // if err != nil { + // return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a valid URL") + // } + + // if metadataURL.Scheme != "https" { + // return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a HTTPS URL") + // } + // } + + // switch p.NameIDFormat { + // case "", + // string(saml.PersistentNameIDFormat), + // string(saml.EmailAddressNameIDFormat), + // string(saml.TransientNameIDFormat), + // string(saml.UnspecifiedNameIDFormat): + // // it's valid + + // default: + // return badRequestError(ErrorCodeValidationFailed, "name_id_format must be unspecified or one of %v", strings.Join([]string{ + // string(saml.PersistentNameIDFormat), + // string(saml.EmailAddressNameIDFormat), + // string(saml.TransientNameIDFormat), + // string(saml.UnspecifiedNameIDFormat), + // }, ", ")) + // } + + return nil +} + +func (p *CreateOIDCSSOProviderParams) metadata(ctx context.Context) (*conf.GenericOAuthProviderConfiguration, error) { + var discover *OIDCDiscoveryResponse + var err error + + var config *conf.GenericOAuthProviderConfiguration + + if p.DiscoveryURL != "" { + discover, err = fetchOIDCMetadata(ctx, p.DiscoveryURL) + if err != nil { + return nil, err + } + config = &conf.GenericOAuthProviderConfiguration{ + OAuthProviderConfiguration: &conf.OAuthProviderConfiguration{ + ClientID: []string{p.ClientId}, + Secret: p.Secret, + URL: discover.Issuer, + ApiURL: discover.UserInfoEndpoint, + RedirectURI: "", // TODO: figure out how to get the data + }, + Issuer: discover.Issuer, + AuthURL: discover.AuthorizationEndpoint, + TokenURL: discover.TokenEndpoint, + UserInfoURL: discover.UserInfoEndpoint, + UserDataMapping: p.AttributeMapping.Keys, + } + + log.Println(p.AttributeMapping) + } else if p.DiscoveryURL == "" && true { + config = &conf.GenericOAuthProviderConfiguration{ + OAuthProviderConfiguration: &conf.OAuthProviderConfiguration{}, + } + } else { + // impossible situation if you called validate() prior + return nil, nil + } + + // metadata, err := parseSAMLMetadata(rawMetadata) + // if err != nil { + // return nil, err + // } + + return config, nil +} + +// func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { +// if !utf8.Valid(rawMetadata) { +// return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") +// } + +// metadata, err := samlsp.ParseMetadata(rawMetadata) +// if err != nil { +// return nil, err +// } + +// if metadata.EntityID == "" { +// return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain an EntityID") +// } + +// if len(metadata.IDPSSODescriptors) < 1 { +// return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain any IDPSSODescriptor") +// } + +// if len(metadata.IDPSSODescriptors) > 1 { +// return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata contains multiple IDPSSODescriptors") +// } + +// return metadata, nil +// } + +type OIDCDiscoveryResponse struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserInfoEndpoint string `json:"userinfo_endpoint"` + JWKSURI string `json:"jwks_uri"` + ScopesSupported []string `json:"scopes_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` +} + +func fetchOIDCMetadata(ctx context.Context, issuerURL string) (*OIDCDiscoveryResponse, error) { + // Construct the well-known URL + discoveryURL := fmt.Sprintf("%s/.well-known/openid-configuration", issuerURL) + + req, err := http.NewRequest(http.MethodGet, discoveryURL, nil) + if err != nil { + return nil, internalServerError("Unable to create a request to metadata_url").WithInternalError(err) + } + + req = req.WithContext(ctx) + + // req.Header.Set("Accept", "application/xml;charset=UTF-8") + req.Header.Set("Accept-Charset", "UTF-8") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + defer utilities.SafeClose(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, badRequestError(ErrorCodeSAMLMetadataFetchFailed, "HTTP %v error fetching OIDC Metadata from URL '%s'", resp.StatusCode, issuerURL) + } + + // Decode the JSON response into a struct + var config OIDCDiscoveryResponse + if err := json.NewDecoder(resp.Body).Decode(&config); err != nil { + return nil, err + } + + return &config, nil +} + +// adminSSOProvidersCreate creates a new SAML Identity Provider in the system. +func (a *API) adminOIDCSSOProvidersCreate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + params := &CreateOIDCSSOProviderParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if err := params.validate(false /* <- forUpdate */); err != nil { + return err + } + + log.Println("20") + config, err := params.metadata(ctx) + if err != nil { + return err + } + + log.Println("21") + existingProvider, err := models.FindOIDCProviderByEntityID(db, params.ClientId, params.AuthURL) + if err != nil && !models.IsNotFoundError(err) { + return err + } + log.Println("22") + if existingProvider != nil { + return unprocessableEntityError(ErrorCodeSAMLIdPAlreadyExists, "OIDC Identity Provider with this ClientID (%s) and AuthURL (%s) already exists", params.ClientId, params.AuthURL) + } + log.Println("23") + provider := &models.SSOProvider{ + // TODO handle Name, Description, Attribute Mapping + SAMLProvider: nil, + OIDCProvider: &models.OIDCProvider{ + Issuer: config.Issuer, + ClientId: config.ClientID[0], + AuthURL: config.AuthURL, + TokenURL: config.TokenURL, + UserInfoURL: config.UserInfoURL, + Secret: config.Secret, + RedirectURI: config.RedirectURI, + AttributeMapping: models.UserDataMapping{Keys: config.UserDataMapping}, + }, + } + log.Println("24") + + // if params.MetadataURL != "" { + // provider.SAMLProvider.MetadataURL = ¶ms.MetadataURL + // } + + // if params.NameIDFormat != "" { + // provider.SAMLProvider.NameIDFormat = ¶ms.NameIDFormat + // } + + // provider.SAMLProvider.AttributeMapping = params.AttributeMapping + + for _, domain := range params.Domains { + existingProvider, err := models.FindSSOProviderByDomain(db, domain) + if err != nil && !models.IsNotFoundError(err) { + return err + } + if existingProvider != nil { + return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) + } + + provider.SSODomains = append(provider.SSODomains, models.SSODomain{ + Domain: domain, + }) + } + log.Println("25") + + if err := db.Transaction(func(tx *storage.Connection) error { + + if terr := tx.Eager().Create(provider); terr != nil { + return terr + } + + return tx.Eager().Load(provider) + }); err != nil { + return err + } + log.Println("26") + + return sendJSON(w, http.StatusCreated, provider) +} + +// adminSSOProvidersGet returns an existing SAML Identity Provider in the system. +func (a *API) adminOIDCSSOProvidersGet(w http.ResponseWriter, r *http.Request) error { + provider := getSSOProvider(r.Context()) + + return sendJSON(w, http.StatusOK, provider) +} + +// adminSSOProvidersUpdate updates a provider with the provided diff values. +// func (a *API) adminOIDCSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) error { +// ctx := r.Context() +// db := a.db.WithContext(ctx) + +// params := &CreateSSOProviderParams{} +// if err := retrieveRequestParams(r, params); err != nil { +// return err +// } + +// if err := params.validate(true /* <- forUpdate */); err != nil { +// return err +// } + +// modified := false +// updateSAMLProvider := false + +// provider := getSSOProvider(ctx) + +// if params.MetadataXML != "" || params.MetadataURL != "" { +// // metadata is being updated +// rawMetadata, metadata, err := params.metadata(ctx) +// if err != nil { +// return err +// } + +// if provider.SAMLProvider.EntityID != metadata.EntityID { +// return badRequestError(ErrorCodeSAMLEntityIDMismatch, "SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) +// } + +// if params.MetadataURL != "" { +// provider.SAMLProvider.MetadataURL = ¶ms.MetadataURL +// } + +// provider.SAMLProvider.MetadataXML = string(rawMetadata) +// updateSAMLProvider = true +// modified = true +// } + +// // domains are being "updated" only when params.Domains is not nil, if +// // it was nil (but not `[]`) then the caller is expecting not to modify +// // the domains +// updateDomains := params.Domains != nil + +// var createDomains, deleteDomains []models.SSODomain +// keepDomains := make(map[string]bool) + +// for _, domain := range params.Domains { +// existingProvider, err := models.FindSSOProviderByDomain(db, domain) +// if err != nil && !models.IsNotFoundError(err) { +// return err +// } +// if existingProvider != nil { +// if existingProvider.ID == provider.ID { +// keepDomains[domain] = true +// } else { +// return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) +// } +// } else { +// modified = true +// createDomains = append(createDomains, models.SSODomain{ +// Domain: domain, +// SSOProviderID: provider.ID, +// }) +// } +// } + +// if updateDomains { +// for i, domain := range provider.SSODomains { +// if !keepDomains[domain.Domain] { +// modified = true +// deleteDomains = append(deleteDomains, provider.SSODomains[i]) +// } +// } +// } + +// updateAttributeMapping := false +// if params.AttributeMapping.Keys != nil { +// updateAttributeMapping = !provider.SAMLProvider.AttributeMapping.Equal(¶ms.AttributeMapping) +// if updateAttributeMapping { +// modified = true +// provider.SAMLProvider.AttributeMapping = params.AttributeMapping +// } +// } + +// nameIDFormat := "" +// if provider.SAMLProvider.NameIDFormat != nil { +// nameIDFormat = *provider.SAMLProvider.NameIDFormat +// } + +// if params.NameIDFormat != nameIDFormat { +// modified = true + +// if params.NameIDFormat == "" { +// provider.SAMLProvider.NameIDFormat = nil +// } else { +// provider.SAMLProvider.NameIDFormat = ¶ms.NameIDFormat +// } +// } + +// if modified { +// if err := db.Transaction(func(tx *storage.Connection) error { +// if terr := tx.Eager().Update(provider); terr != nil { +// return terr +// } + +// if updateDomains { +// if terr := tx.Destroy(deleteDomains); terr != nil { +// return terr +// } + +// if terr := tx.Eager().Create(createDomains); terr != nil { +// return terr +// } +// } + +// if updateAttributeMapping || updateSAMLProvider { +// if terr := tx.Eager().Update(&provider.SAMLProvider); terr != nil { +// return terr +// } +// } + +// return tx.Eager().Load(provider) +// }); err != nil { +// return unprocessableEntityError(ErrorCodeConflict, "Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) +// } +// } + +// return sendJSON(w, http.StatusOK, provider) +// } + +// adminSSOProvidersDelete deletes a SAML identity provider. +func (a *API) adminOIDCSSOProvidersDelete(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + provider := getSSOProvider(ctx) + + if err := db.Transaction(func(tx *storage.Connection) error { + return tx.Eager().Destroy(provider) + }); err != nil { + return err + } + + return sendJSON(w, http.StatusOK, provider) +} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 21216fedb..60b32b2f3 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -65,6 +65,16 @@ type OAuthProviderConfiguration struct { SkipNonceCheck bool `json:"skip_nonce_check" split_words:"true"` } +// GenericOAuthProviderConfiguration holds all config related to generic OAuth providers. +type GenericOAuthProviderConfiguration struct { + *OAuthProviderConfiguration + AuthURL string `json:"auth_url" envconfig:"AUTH_URL"` + TokenURL string `json:"token_url" envconfig:"TOKEN_URL"` + Issuer string `json:"issuer"` + UserInfoURL string `json:"userinfo_url" split_words:"true"` + UserDataMapping map[string]string `json:"mapping" split_words:"true"` +} + type AnonymousProviderConfiguration struct { Enabled bool `json:"enabled" default:"false"` } diff --git a/internal/models/factor.go b/internal/models/factor.go index 7c6f6dd30..58dd81581 100644 --- a/internal/models/factor.go +++ b/internal/models/factor.go @@ -49,6 +49,7 @@ const ( EmailChange TokenRefresh Anonymous + SSOOIDC ) func (authMethod AuthenticationMethod) String() string { @@ -67,6 +68,8 @@ func (authMethod AuthenticationMethod) String() string { return "invite" case SSOSAML: return "sso/saml" + case SSOOIDC: + return "sso/oidc" case MagicLink: return "magiclink" case EmailSignup: @@ -102,6 +105,8 @@ func ParseAuthenticationMethod(authMethod string) (AuthenticationMethod, error) return Invite, nil case "sso/saml": return SSOSAML, nil + case "sso/oidc": + return SSOOIDC, nil case "magiclink": return MagicLink, nil case "email/signup": diff --git a/internal/models/sso.go b/internal/models/sso.go index 28c2429ac..1cff03a61 100644 --- a/internal/models/sso.go +++ b/internal/models/sso.go @@ -4,6 +4,7 @@ import ( "database/sql" "database/sql/driver" "encoding/json" + "log" "reflect" "strings" "time" @@ -12,14 +13,16 @@ import ( "github.com/crewjam/saml/samlsp" "github.com/gofrs/uuid" "github.com/pkg/errors" + "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/storage" ) type SSOProvider struct { ID uuid.UUID `db:"id" json:"id"` - SAMLProvider SAMLProvider `has_one:"saml_providers" fk_id:"sso_provider_id" json:"saml,omitempty"` - SSODomains []SSODomain `has_many:"sso_domains" fk_id:"sso_provider_id" json:"domains"` + SAMLProvider *SAMLProvider `has_one:"saml_providers" fk_id:"sso_provider_id" json:"saml,omitempty"` + OIDCProvider *OIDCProvider `has_one:"oidc_providers" fk_id:"sso_provider_id" json:"oidc,omitempty"` + SSODomains []SSODomain `has_many:"sso_domains" fk_id:"sso_provider_id" json:"domains"` CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` @@ -127,14 +130,104 @@ type SAMLProvider struct { UpdatedAt time.Time `db:"updated_at" json:"-"` } +type UserDataMapping struct { + Keys map[string]string `json:"keys,omitempty"` +} + +func (m *UserDataMapping) Scan(src interface{}) error { + b, ok := src.([]byte) + if !ok { + return errors.New("scan source was not []byte") + } + err := json.Unmarshal(b, m) + if err != nil { + return err + } + return nil +} + +func (m UserDataMapping) Value() (driver.Value, error) { + b, err := json.Marshal(m) + if err != nil { + return nil, err + } + return string(b), nil +} + +type OIDCProvider struct { + ID uuid.UUID `db:"id" json:"-"` + + SSOProvider *SSOProvider `belongs_to:"sso_providers" json:"-"` + SSOProviderID uuid.UUID `db:"sso_provider_id" json:"-"` + + Issuer string `db:"issuer" json:"issuer"` + ClientId string `db:"client_id" json:"client_id"` + Secret string `db:"secret" json:"secret"` + AuthURL string `db:"auth_url" json:"auth_url"` + TokenURL string `db:"token_url" json:"token_url"` + UserInfoURL string `db:"userinfo_url" json:"userinfo_url"` + + RedirectURI string `db:"redirect_uri" json:"redirect_uri"` + // MetadataXML string `db:"metadata_xml" json:"metadata_xml,omitempty"` + // MetadataURL *string `db:"metadata_url" json:"metadata_url,omitempty"` + + AttributeMapping UserDataMapping `db:"attribute_mapping" json:"attribute_mapping,omitempty"` + + // NameIDFormat *string `db:"name_id_format" json:"name_id_format,omitempty"` + + CreatedAt time.Time `db:"created_at" json:"-"` + UpdatedAt time.Time `db:"updated_at" json:"-"` +} + func (p SAMLProvider) TableName() string { return "saml_providers" } +func (p OIDCProvider) TableName() string { + return "oidc_providers" +} + func (p SAMLProvider) EntityDescriptor() (*saml.EntityDescriptor, error) { return samlsp.ParseMetadata([]byte(p.MetadataXML)) } +func (p OIDCProvider) GenericProviderConfig() (conf.GenericOAuthProviderConfiguration, error) { + log.Println("11") + + // Initialize OAuthProviderConfiguration with proper fields + oauthConfig := &conf.OAuthProviderConfiguration{ + ClientID: []string{p.ClientId}, // assuming p.ClientId is correct + Secret: p.Secret, //"ZIttFqNAGsEWG4ZGYshk3dbYNe0m496E", // assuming p.Secret exists + RedirectURI: "", // assuming p.RedirectURI exists + URL: p.Issuer, // assuming p.URL exists + ApiURL: p.UserInfoURL, // assuming p.ApiURL exists + Enabled: true, // assuming p.Enabled exists + SkipNonceCheck: true, // assuming p.SkipNonceCheck exists + } + + // Initialize GenericOAuthProviderConfiguration with oauthConfig + providerConfig := conf.GenericOAuthProviderConfiguration{ + OAuthProviderConfiguration: oauthConfig, + AuthURL: p.AuthURL, // assuming p.AuthURL exists + TokenURL: p.TokenURL, // assuming p.TokenURL exists + Issuer: p.Issuer, + UserInfoURL: p.UserInfoURL, + UserDataMapping: p.AttributeMapping.Keys, /*[string]string{ + "Subject": "sub", + "Email": "email", + "EmailVerified": "email_verified", + }*/ // assuming p.UserDataMapping exists + } + + return providerConfig, nil + + // // Pass the providerConfig to NewGenericProvider + // provider, err := provider.NewGenericProvider(providerConfig, "oidc") + + // log.Println("12") + // return provider, err +} + type SSODomain struct { ID uuid.UUID `db:"id" json:"-"` @@ -167,10 +260,29 @@ type SAMLRelayState struct { FlowState *FlowState `db:"-" json:"flow_state,omitempty" belongs_to:"flow_state"` } +type OIDCFlowState struct { + ID uuid.UUID `db:"id"` + + SSOProviderID uuid.UUID `db:"sso_provider_id"` + + State string `db:"state"` + + RedirectTo string `db:"redirect_to"` + + CreatedAt time.Time `db:"created_at" json:"-"` + UpdatedAt time.Time `db:"updated_at" json:"-"` + FlowStateID *uuid.UUID `db:"flow_state_id" json:"flow_state_id,omitempty"` + FlowState *FlowState `db:"-" json:"flow_state,omitempty" belongs_to:"flow_state"` +} + func (s SAMLRelayState) TableName() string { return "saml_relay_states" } +func (s OIDCFlowState) TableName() string { + return "oidc_relay_states" +} + func FindSAMLProviderByEntityID(tx *storage.Connection, entityId string) (*SSOProvider, error) { var samlProvider SAMLProvider if err := tx.Q().Where("entity_id = ?", entityId).First(&samlProvider); err != nil { @@ -189,6 +301,24 @@ func FindSAMLProviderByEntityID(tx *storage.Connection, entityId string) (*SSOPr return &ssoProvider, nil } +func FindOIDCProviderByEntityID(tx *storage.Connection, clientId string, authUrl string) (*SSOProvider, error) { + var samlProvider OIDCProvider + if err := tx.Q().Where("client_id = ?", clientId).Where("auth_url = ?", clientId).First(&samlProvider); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SSOProviderNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding SAML SSO provider by EntityID") + } + + var ssoProvider SSOProvider + if err := tx.Eager().Q().Where("id = ?", samlProvider.SSOProviderID).First(&ssoProvider); err != nil { + return nil, errors.Wrap(err, "error finding SAML SSO provider by ID (via EntityID)") + } + + return &ssoProvider, nil +} + func FindSSOProviderByID(tx *storage.Connection, id uuid.UUID) (*SSOProvider, error) { var ssoProvider SSOProvider @@ -247,6 +377,20 @@ func FindAllSAMLProviders(tx *storage.Connection) ([]SSOProvider, error) { return providers, nil } +func FindAllOIDCProviders(tx *storage.Connection) ([]SSOProvider, error) { + var providers []SSOProvider + + if err := tx.Eager().All(&providers); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil + } + + return nil, errors.Wrap(err, "error loading all OIDC SSO providers") + } + + return providers, nil +} + func FindSAMLRelayStateByID(tx *storage.Connection, id uuid.UUID) (*SAMLRelayState, error) { var state SAMLRelayState @@ -260,3 +404,18 @@ func FindSAMLRelayStateByID(tx *storage.Connection, id uuid.UUID) (*SAMLRelaySta return &state, nil } + +func FindOIDCFlowStateByID(tx *storage.Connection, stateId string) (*OIDCFlowState, error) { + var state OIDCFlowState + log.Println(stateId) + if err := tx.Eager().Q().Where("state = ?", stateId).First(&state); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + log.Println(err) + return nil, SAMLRelayStateNotFoundError{} + } + log.Println(err) + return nil, errors.Wrap(err, "error loading OIDC Flow State") + } + + return &state, nil +} diff --git a/internal/models/sso_test.go b/internal/models/sso_test.go index b6c965630..9993c043a 100644 --- a/internal/models/sso_test.go +++ b/internal/models/sso_test.go @@ -43,7 +43,7 @@ func (ts *SSOTestSuite) TestConstraints() { examples := []exampleSpec{ { Provider: &SSOProvider{ - SAMLProvider: SAMLProvider{ + SAMLProvider: &SAMLProvider{ EntityID: "", MetadataXML: "", }, @@ -51,7 +51,7 @@ func (ts *SSOTestSuite) TestConstraints() { }, { Provider: &SSOProvider{ - SAMLProvider: SAMLProvider{ + SAMLProvider: &SAMLProvider{ EntityID: "https://example.com/saml/metadata", MetadataXML: "", }, @@ -59,7 +59,7 @@ func (ts *SSOTestSuite) TestConstraints() { }, { Provider: &SSOProvider{ - SAMLProvider: SAMLProvider{ + SAMLProvider: &SAMLProvider{ EntityID: "https://example.com/saml/metadata", MetadataXML: "", }, @@ -79,7 +79,7 @@ func (ts *SSOTestSuite) TestConstraints() { func (ts *SSOTestSuite) TestDomainUniqueness() { require.NoError(ts.T(), ts.db.Eager().Create(&SSOProvider{ - SAMLProvider: SAMLProvider{ + SAMLProvider: &SAMLProvider{ EntityID: "https://example.com/saml/metadata1", MetadataXML: "", }, @@ -91,7 +91,7 @@ func (ts *SSOTestSuite) TestDomainUniqueness() { })) require.Error(ts.T(), ts.db.Eager().Create(&SSOProvider{ - SAMLProvider: SAMLProvider{ + SAMLProvider: &SAMLProvider{ EntityID: "https://example.com/saml/metadata2", MetadataXML: "", }, @@ -105,7 +105,7 @@ func (ts *SSOTestSuite) TestDomainUniqueness() { func (ts *SSOTestSuite) TestEntityIDUniqueness() { require.NoError(ts.T(), ts.db.Eager().Create(&SSOProvider{ - SAMLProvider: SAMLProvider{ + SAMLProvider: &SAMLProvider{ EntityID: "https://example.com/saml/metadata", MetadataXML: "", }, @@ -117,7 +117,7 @@ func (ts *SSOTestSuite) TestEntityIDUniqueness() { })) require.Error(ts.T(), ts.db.Eager().Create(&SSOProvider{ - SAMLProvider: SAMLProvider{ + SAMLProvider: &SAMLProvider{ EntityID: "https://example.com/saml/metadata", MetadataXML: "", }, @@ -131,7 +131,7 @@ func (ts *SSOTestSuite) TestEntityIDUniqueness() { func (ts *SSOTestSuite) TestFindSSOProviderForEmailAddress() { provider := &SSOProvider{ - SAMLProvider: SAMLProvider{ + SAMLProvider: &SAMLProvider{ EntityID: "https://example.com/saml/metadata", MetadataXML: "", }, @@ -182,7 +182,7 @@ func (ts *SSOTestSuite) TestFindSSOProviderForEmailAddress() { func (ts *SSOTestSuite) TestFindSAMLProviderByEntityID() { provider := &SSOProvider{ - SAMLProvider: SAMLProvider{ + SAMLProvider: &SAMLProvider{ EntityID: "https://example.com/saml/metadata", MetadataXML: "", }, diff --git a/migrations/20240819081613_add_oidc_sso.up.sql b/migrations/20240819081613_add_oidc_sso.up.sql new file mode 100644 index 000000000..ad34b9d9f --- /dev/null +++ b/migrations/20240819081613_add_oidc_sso.up.sql @@ -0,0 +1,49 @@ +do $$ +begin + create table if not exists {{ index .Options "Namespace" }}.oidc_providers ( + id uuid not null, + sso_provider_id uuid not null, + issuer text not null, + client_id text not null, + secret text not null, + auth_url text not null, + token_url text not null, + userinfo_url text not null, + redirect_uri text not null, + -- metadata_url text null, + attribute_mapping jsonb null, + created_at timestamptz null, + updated_at timestamptz null, + primary key (id), + foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade + -- constraint "metadata_xml not empty" check (char_length(metadata_xml) > 0), + -- constraint "metadata_url not empty" check (metadata_url = null or char_length(metadata_url) > 0), + -- constraint "entity_id not empty" check (char_length(entity_id) > 0) + ); + + create index if not exists oidc_providers_sso_provider_id_idx on {{ index .Options "Namespace" }}.oidc_providers (sso_provider_id); + + comment on table {{ index .Options "Namespace" }}.oidc_providers is 'Auth: Manages OIDC Identity Provider connections.'; + + create table if not exists {{ index .Options "Namespace" }}.oidc_relay_states ( + id uuid not null, + sso_provider_id uuid not null, + state text not null, + for_email text null, + redirect_to text null, + created_at timestamptz null, + updated_at timestamptz null, + flow_state_id uuid null, + primary key (id), + foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade, + foreign key (flow_state_id) references {{ index .Options "Namespace" }}.flow_state (id) on delete cascade, + constraint "state not empty" check(char_length(state) > 0) + ); + + create index if not exists oidc_relay_states_sso_provider_id_idx on {{ index .Options "Namespace" }}.oidc_relay_states (sso_provider_id); + create index if not exists oidc_relay_states_for_email_idx on {{ index .Options "Namespace" }}.oidc_relay_states (for_email); + + comment on table {{ index .Options "Namespace" }}.oidc_relay_states is 'Auth: Contains OIDC Relay State information for each Service Provider initiated login.'; + + +end $$; From 24e973830bcb38d62d261ef5853b2171a73c8eb3 Mon Sep 17 00:00:00 2001 From: smndtrl Date: Wed, 21 Aug 2024 09:19:04 +0000 Subject: [PATCH 2/4] cleanup --- internal/models/sso.go | 43 ++++++++++++------------------------------ 1 file changed, 12 insertions(+), 31 deletions(-) diff --git a/internal/models/sso.go b/internal/models/sso.go index 1cff03a61..5526f0edd 100644 --- a/internal/models/sso.go +++ b/internal/models/sso.go @@ -4,7 +4,6 @@ import ( "database/sql" "database/sql/driver" "encoding/json" - "log" "reflect" "strings" "time" @@ -168,13 +167,9 @@ type OIDCProvider struct { UserInfoURL string `db:"userinfo_url" json:"userinfo_url"` RedirectURI string `db:"redirect_uri" json:"redirect_uri"` - // MetadataXML string `db:"metadata_xml" json:"metadata_xml,omitempty"` - // MetadataURL *string `db:"metadata_url" json:"metadata_url,omitempty"` AttributeMapping UserDataMapping `db:"attribute_mapping" json:"attribute_mapping,omitempty"` - // NameIDFormat *string `db:"name_id_format" json:"name_id_format,omitempty"` - CreatedAt time.Time `db:"created_at" json:"-"` UpdatedAt time.Time `db:"updated_at" json:"-"` } @@ -192,40 +187,27 @@ func (p SAMLProvider) EntityDescriptor() (*saml.EntityDescriptor, error) { } func (p OIDCProvider) GenericProviderConfig() (conf.GenericOAuthProviderConfiguration, error) { - log.Println("11") - // Initialize OAuthProviderConfiguration with proper fields oauthConfig := &conf.OAuthProviderConfiguration{ - ClientID: []string{p.ClientId}, // assuming p.ClientId is correct - Secret: p.Secret, //"ZIttFqNAGsEWG4ZGYshk3dbYNe0m496E", // assuming p.Secret exists - RedirectURI: "", // assuming p.RedirectURI exists - URL: p.Issuer, // assuming p.URL exists - ApiURL: p.UserInfoURL, // assuming p.ApiURL exists - Enabled: true, // assuming p.Enabled exists - SkipNonceCheck: true, // assuming p.SkipNonceCheck exists + ClientID: []string{p.ClientId}, + Secret: p.Secret, + RedirectURI: "", + URL: p.Issuer, + ApiURL: p.UserInfoURL, + Enabled: true, + SkipNonceCheck: true, } - // Initialize GenericOAuthProviderConfiguration with oauthConfig providerConfig := conf.GenericOAuthProviderConfiguration{ OAuthProviderConfiguration: oauthConfig, - AuthURL: p.AuthURL, // assuming p.AuthURL exists - TokenURL: p.TokenURL, // assuming p.TokenURL exists + AuthURL: p.AuthURL, + TokenURL: p.TokenURL, Issuer: p.Issuer, UserInfoURL: p.UserInfoURL, - UserDataMapping: p.AttributeMapping.Keys, /*[string]string{ - "Subject": "sub", - "Email": "email", - "EmailVerified": "email_verified", - }*/ // assuming p.UserDataMapping exists + UserDataMapping: p.AttributeMapping.Keys, } return providerConfig, nil - - // // Pass the providerConfig to NewGenericProvider - // provider, err := provider.NewGenericProvider(providerConfig, "oidc") - - // log.Println("12") - // return provider, err } type SSODomain struct { @@ -407,13 +389,12 @@ func FindSAMLRelayStateByID(tx *storage.Connection, id uuid.UUID) (*SAMLRelaySta func FindOIDCFlowStateByID(tx *storage.Connection, stateId string) (*OIDCFlowState, error) { var state OIDCFlowState - log.Println(stateId) + if err := tx.Eager().Q().Where("state = ?", stateId).First(&state); err != nil { if errors.Cause(err) == sql.ErrNoRows { - log.Println(err) return nil, SAMLRelayStateNotFoundError{} } - log.Println(err) + return nil, errors.Wrap(err, "error loading OIDC Flow State") } From 7acf4cd445b29a8a06e1398c693ce1c01b88f0c4 Mon Sep 17 00:00:00 2001 From: smndtrl Date: Wed, 21 Aug 2024 09:27:32 +0000 Subject: [PATCH 3/4] remove log --- internal/api/external.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/internal/api/external.go b/internal/api/external.go index de3ff6e3a..7db0652d3 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "log" "net/http" "net/url" "strconv" @@ -251,7 +250,6 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re // This means that the callback is using PKCE // Set the flowState.AuthCode to the query param here rurl, err = a.prepPKCERedirectURL(rurl, flowState.AuthCode) - log.Println("rurl", rurl) if err != nil { return err } From ff28d3b99fc9ed7a288a7e2df869d99376ac6f3d Mon Sep 17 00:00:00 2001 From: smndtrl Date: Wed, 21 Aug 2024 11:48:42 +0000 Subject: [PATCH 4/4] saml/oidc enable flags --- internal/api/api.go | 4 +++- internal/api/middleware.go | 18 +++++++++++++++++- internal/api/middleware_test.go | 2 +- internal/conf/configuration.go | 1 + internal/conf/oidc.go | 14 ++++++++++++++ 5 files changed, 36 insertions(+), 3 deletions(-) create mode 100644 internal/conf/oidc.go diff --git a/internal/api/api.go b/internal/api/api.go index 67e02779e..cfff81916 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -258,7 +258,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne }) r.Route("/sso", func(r *router) { - r.Use(api.requireSAMLEnabled) + r.Use(api.requireSSOEnabled) r.With(api.limitHandler( // Allow requests at the specified rate per 5 minutes. tollbooth.NewLimiter(api.config.RateLimitSso/(60*5), &limiter.ExpirableOptions{ @@ -267,6 +267,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne )).With(api.verifyCaptcha).Post("/", api.SingleSignOn) r.Route("/saml", func(r *router) { + r.Use(api.requireSSOSAMLEnabled) r.Get("/metadata", api.SAMLMetadata) r.With(api.limitHandler( @@ -278,6 +279,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne }) r.Route("/oidc", func(r *router) { + r.Use(api.requireSSOOIDCEnabled) r.Route("/callback", func(r *router) { r.Use(api.isValidExternalHost) r.Use(api.loadSSOOIDCFlowState) diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 8caa5d885..edfc23338 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -223,7 +223,7 @@ func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (con return withExternalHost(ctx, u), nil } -func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { +func (a *API) requireSSOSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.SAML.Enabled { return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled") @@ -231,6 +231,22 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont return ctx, nil } +func (a *API) requireSSOOIDCEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + if !a.config.OIDC.Enabled { + return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "OIDC is disabled") + } + return ctx, nil +} + +func (a *API) requireSSOEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + if !(a.config.OIDC.Enabled || a.config.SAML.Enabled) { + return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "Either SAML or OIDC for SSO need to be enabled") + } + return ctx, nil +} + func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.Security.ManualLinkingEnabled { diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index eb8c5da3b..4c6570910 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -287,7 +287,7 @@ func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() { req := httptest.NewRequest("GET", "http://localhost", nil) w := httptest.NewRecorder() - _, err := ts.API.requireSAMLEnabled(w, req) + _, err := ts.API.requireSSOSAMLEnabled(w, req) require.Equal(ts.T(), c.expectedErr, err) }) } diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 60b32b2f3..e60fcf947 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -273,6 +273,7 @@ type GlobalConfiguration struct { Duration int `json:"duration"` } `json:"cookies"` SAML SAMLConfiguration `json:"saml"` + OIDC OIDCConfiguration `json:"oidc"` CORS CORSConfiguration `json:"cors"` } diff --git a/internal/conf/oidc.go b/internal/conf/oidc.go new file mode 100644 index 000000000..a705e6b68 --- /dev/null +++ b/internal/conf/oidc.go @@ -0,0 +1,14 @@ +package conf + +// OIDCConfiguration holds configuration for native OIDC SSO support. +type OIDCConfiguration struct { + Enabled bool `json:"enabled"` +} + +func (c *OIDCConfiguration) Validate() error { + if c.Enabled { + + } + + return nil +}