Skip to content

Commit

Permalink
Merge pull request #79 from anton-k/safer-urls
Browse files Browse the repository at this point in the history
Make run-time errors on wrong URL top Server API
  • Loading branch information
anton-k committed Nov 26, 2023
2 parents e7e7fca + 387a30c commit f29ea38
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 5 deletions.
53 changes: 49 additions & 4 deletions mig/src/Mig/Core/Class/Url.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ import Data.Text (Text)
import Data.Text qualified as Text
import GHC.TypeLits
import Mig.Core.Api (Path (..), PathItem (..), flatApi, fromFlatApi)
import Mig.Core.Class.Route (Route (..))
import Mig.Core.Server (Server (..), getServerPaths)
import Mig.Core.Types.Info (RouteInfo, routeHasCapture, routeHasOptionalQuery, routeHasQuery, routeHasQueryFlag)
import Mig.Core.Types.Pair
import Mig.Core.Types.Route
import Safe (headMay)
import Web.HttpApiData

-- | Url-template type.
Expand Down Expand Up @@ -161,46 +164,88 @@ instance ToUrl Url where

instance (KnownSymbol sym, ToHttpApiData a, ToUrl b) => ToUrl (Query sym a -> b) where
toUrl server = \(Query val) ->
mapUrl (insertQuery (getName @sym) (toUrlPiece val)) (toUrl @b server)
whenOrError (hasQuery (getName @sym) server) noQuery $
mapUrl (insertQuery (getName @sym) (toUrlPiece val)) (toUrl @b server)
where
noQuery = noInputMessage ("query with name: " <> getName @sym) server

mapUrl f a = \query -> mapUrl f (a query)
urlArity = urlArity @b

insertQuery :: Text -> Text -> Url -> Url
insertQuery name val url = url{queries = (name, val) : url.queries}

hasQuery :: Text -> Server m -> Bool
hasQuery name = hasInput (routeHasQuery name)

-- optional query

instance (KnownSymbol sym, ToHttpApiData a, ToUrl b) => ToUrl (Optional sym a -> b) where
toUrl server = \(Optional mVal) ->
mapUrl (maybe id (insertQuery (getName @sym) . toUrlPiece) mVal) (toUrl @b server)
whenOrError (hasOptionalQuery (getName @sym) server) noOptionalQuery $
mapUrl (maybe id (insertQuery (getName @sym) . toUrlPiece) mVal) (toUrl @b server)
where
noOptionalQuery = noInputMessage ("optional query with name: " <> getName @sym) server

mapUrl f a = \query -> mapUrl f (a query)
urlArity = urlArity @b

hasOptionalQuery :: Text -> Server m -> Bool
hasOptionalQuery name = hasInput (routeHasOptionalQuery name)

-- query flag

instance (KnownSymbol sym, ToUrl b) => ToUrl (QueryFlag sym -> b) where
toUrl server = \(QueryFlag val) ->
mapUrl (insertQuery (getName @sym) (toUrlPiece val)) (toUrl @b server)
whenOrError (hasQueryFlag (getName @sym) server) noQueryFlag $
mapUrl (insertQuery (getName @sym) (toUrlPiece val)) (toUrl @b server)
where
noQueryFlag = noInputMessage ("query flag with name: " <> getName @sym) server

mapUrl f a = \query -> mapUrl f (a query)
urlArity = urlArity @b

hasQueryFlag :: Text -> Server m -> Bool
hasQueryFlag name = hasInput (routeHasQueryFlag name)

-- capture

instance (KnownSymbol sym, ToHttpApiData a, ToUrl b) => ToUrl (Capture sym a -> b) where
toUrl server = \(Capture val) ->
mapUrl (insertCapture (getName @sym) (toUrlPiece val)) (toUrl @b server)
whenOrError (hasCapture (getName @sym) server) noCapture $
mapUrl (insertCapture (getName @sym) (toUrlPiece val)) (toUrl @b server)
where
noCapture = noInputMessage ("Capture with name: " <> getName @sym) server

mapUrl f a = \capture -> mapUrl f (a capture)
urlArity = urlArity @b

insertCapture :: Text -> Text -> Url -> Url
insertCapture name val url = url{captures = Map.insert name val url.captures}

hasCapture :: Text -> Server m -> Bool
hasCapture name = hasInput (routeHasCapture name)

-------------------------------------------------------------------------------------
-- utils

getName :: forall sym a. (KnownSymbol sym, IsString a) => a
getName = fromString (symbolVal (Proxy @sym))

hasInput :: (RouteInfo -> Bool) -> Server m -> Bool
hasInput check (Server api) =
maybe False (check . (.info) . snd) $ headMay $ flatApi api

noInputMessage :: String -> Server m -> String
noInputMessage item (Server api) =
unlines
[ unwords ["Server has no", item, "at route", route]
, "Check the order of routes on the left side of toUrl expression"
]
where
route = maybe "unknown" (Text.unpack . toUrlPiece . fst) $ headMay (flatApi api)

whenOrError :: Bool -> String -> a -> a
whenOrError cond message a
| cond = a
| otherwise = error message
47 changes: 46 additions & 1 deletion mig/src/Mig/Core/Types/Info.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,17 @@ module Mig.Core.Types.Info (
addQueryFlagInfo,
addOptionalInfo,
addCaptureInfo,

-- * checks
routeHasQuery,
routeHasOptionalQuery,
routeHasQueryFlag,
routeHasCapture,
) where

import Data.List.Extra (firstJust)
import Data.Map.Strict qualified as Map
import Data.OpenApi
import Data.OpenApi (Definitions, Referenced, Schema, ToParamSchema (..), ToSchema (..), declareSchemaRef)
import Data.OpenApi.Declare (runDeclare)
import Data.Proxy
import Data.String
Expand Down Expand Up @@ -208,6 +214,45 @@ addQueryFlagInfo = addRouteInput (QueryFlagInput (getName @sym))
addBodyInfo :: forall ty a. (ToMediaType ty, ToSchema a) => RouteInfo -> RouteInfo
addBodyInfo = addRouteInput (ReqBodyInput (toMediaType @ty) (toSchemaDefs @a))

---------------------------------------------
-- checks

-- | Check that route has query with given name
routeHasQuery :: Text -> RouteInfo -> Bool
routeHasQuery expectedName = routeHasInput isQuery
where
isQuery = \case
QueryInput (IsRequired True) name _ -> expectedName == name
_ -> False

-- | Check that route has query with given name
routeHasOptionalQuery :: Text -> RouteInfo -> Bool
routeHasOptionalQuery expectedName = routeHasInput isOptionalQuery
where
isOptionalQuery = \case
QueryInput (IsRequired False) name _ -> expectedName == name
_ -> False

-- | Check that route has query with given name
routeHasQueryFlag :: Text -> RouteInfo -> Bool
routeHasQueryFlag expectedName = routeHasInput isQueryFlag
where
isQueryFlag = \case
QueryFlagInput name -> expectedName == name
_ -> False

-- | Check that route has query with given name
routeHasCapture :: Text -> RouteInfo -> Bool
routeHasCapture expectedName = routeHasInput isCapture
where
isCapture = \case
CaptureInput name _ -> expectedName == name
_ -> False

-- | Check that route has certain input
routeHasInput :: (RouteInput -> Bool) -> RouteInfo -> Bool
routeHasInput check info = any (check . (.content)) info.inputs

---------------------------------------------
-- utils

Expand Down

0 comments on commit f29ea38

Please sign in to comment.