Skip to content

Commit 1eb056f

Browse files
committed
Support scram sha 256
2 parents 1b62d33 + 90afe88 commit 1eb056f

File tree

1 file changed

+85
-52
lines changed

1 file changed

+85
-52
lines changed

Database/MongoDB/Query.hs

Lines changed: 85 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ module Database.MongoDB.Query (
1111
-- * Database
1212
Database, allDatabases, useDb, thisDatabase,
1313
-- ** Authentication
14-
Username, Password, auth, authMongoCR, authSCRAMSHA1,
14+
Username, Password, auth, authMongoCR, authSCRAMSHA1, authSCRAMSHA256,
1515
-- * Collection
1616
Collection, allCollections,
1717
-- ** Selection
@@ -61,8 +61,10 @@ import Control.Monad
6161
)
6262
import Control.Monad.Reader (MonadReader, ReaderT, ask, asks, local, runReaderT)
6363
import Control.Monad.Trans (MonadIO, liftIO, lift)
64+
import Control.Monad.Trans.Except
6465
import qualified Crypto.Hash.MD5 as MD5
6566
import qualified Crypto.Hash.SHA1 as SHA1
67+
import qualified Crypto.Hash.SHA256 as SHA256
6668
import qualified Crypto.MAC.HMAC as HMAC
6769
import qualified Crypto.Nonce as Nonce
6870
import Data.Binary.Put (runPut)
@@ -285,62 +287,93 @@ authMongoCR usr pss = do
285287
n <- at "nonce" <$> runCommand ["getnonce" =: (1 :: Int)]
286288
true1 "ok" <$> runCommand ["authenticate" =: (1 :: Int), "user" =: usr, "nonce" =: n, "key" =: pwKey n usr pss]
287289

290+
data HashAlgorithm = SHA1 | SHA256 deriving Show
291+
292+
hash :: HashAlgorithm -> B.ByteString -> B.ByteString
293+
hash SHA1 = SHA1.hash
294+
hash SHA256 = SHA256.hash
295+
288296
authSCRAMSHA1 :: MonadIO m => Username -> Password -> Action m Bool
297+
authSCRAMSHA1 = authSCRAMWith SHA1
298+
299+
authSCRAMSHA256 :: MonadIO m => Username -> Password -> Action m Bool
300+
authSCRAMSHA256 = authSCRAMWith SHA256
301+
302+
toAuthResult :: Functor m => ExceptT String (Action m) () -> Action m Bool
303+
toAuthResult = fmap (either (const False) (const True)) . runExceptT
304+
305+
-- | It should technically perform SASLprep, but the implementation is currently id
306+
saslprep :: Text -> Text
307+
saslprep = id
308+
309+
authSCRAMWith :: MonadIO m => HashAlgorithm -> Username -> Password -> Action m Bool
289310
-- ^ Authenticate with the current database, using the SCRAM-SHA-1 authentication mechanism (default in MongoDB server >= 3.0)
290-
authSCRAMSHA1 un pw = do
291-
let hmac = HMAC.hmac SHA1.hash 64
311+
authSCRAMWith algo un pw = toAuthResult $ do
312+
let hmac = HMAC.hmac (hash algo) 64
292313
nonce <- liftIO (Nonce.withGenerator Nonce.nonce128 <&> B64.encode)
293314
let firstBare = B.concat [B.pack $ "n=" ++ T.unpack un ++ ",r=", nonce]
294-
let client1 = ["saslStart" =: (1 :: Int), "mechanism" =: ("SCRAM-SHA-1" :: String), "payload" =: (B.unpack . B64.encode $ B.concat [B.pack "n,,", firstBare]), "autoAuthorize" =: (1 :: Int)]
295-
server1 <- runCommand client1
296-
297-
shortcircuit (true1 "ok" server1) $ do
298-
let serverPayload1 = B64.decodeLenient . B.pack . at "payload" $ server1
299-
let serverData1 = parseSCRAM serverPayload1
300-
let iterations = read . B.unpack $ Map.findWithDefault "1" "i" serverData1
301-
let salt = B64.decodeLenient $ Map.findWithDefault "" "s" serverData1
302-
let snonce = Map.findWithDefault "" "r" serverData1
303-
304-
shortcircuit (B.isInfixOf nonce snonce) $ do
305-
let withoutProof = B.concat [B.pack "c=biws,r=", snonce]
306-
let digestS = B.pack $ T.unpack un ++ ":mongo:" ++ T.unpack pw
307-
let digest = B16.encode $ MD5.hash digestS
308-
let saltedPass = scramHI digest salt iterations
309-
let clientKey = hmac saltedPass (B.pack "Client Key")
310-
let storedKey = SHA1.hash clientKey
311-
let authMsg = B.concat [firstBare, B.pack ",", serverPayload1, B.pack ",", withoutProof]
312-
let clientSig = hmac storedKey authMsg
313-
let pval = B64.encode . BS.pack $ BS.zipWith xor clientKey clientSig
314-
let clientFinal = B.concat [withoutProof, B.pack ",p=", pval]
315-
let serverKey = hmac saltedPass (B.pack "Server Key")
316-
let serverSig = B64.encode $ hmac serverKey authMsg
317-
let client2 = ["saslContinue" =: (1 :: Int), "conversationId" =: (at "conversationId" server1 :: Int), "payload" =: B.unpack (B64.encode clientFinal)]
318-
server2 <- runCommand client2
319-
320-
shortcircuit (true1 "ok" server2) $ do
321-
let serverPayload2 = B64.decodeLenient . B.pack $ at "payload" server2
322-
let serverData2 = parseSCRAM serverPayload2
323-
let serverSigComp = Map.findWithDefault "" "v" serverData2
324-
325-
shortcircuit (serverSig == serverSigComp) $ do
326-
let done = true1 "done" server2
327-
if done
328-
then return True
329-
else do
330-
let client2Step2 = [ "saslContinue" =: (1 :: Int)
331-
, "conversationId" =: (at "conversationId" server1 :: Int)
332-
, "payload" =: String ""]
333-
server3 <- runCommand client2Step2
334-
shortcircuit (true1 "ok" server3) $ do
335-
return True
336-
where
337-
shortcircuit True f = f
338-
shortcircuit False _ = return False
339-
340-
scramHI :: B.ByteString -> B.ByteString -> Int -> B.ByteString
341-
scramHI digest salt iters = snd $ foldl com (u1, u1) [1..(iters-1)]
315+
let client1 =
316+
[ "saslStart" =: (1 :: Int)
317+
, "mechanism" =: case algo of
318+
SHA1 -> "SCRAM-SHA-1" :: String
319+
SHA256 -> "SCRAM-SHA-256"
320+
, "payload" =: (B.unpack . B64.encode $ B.concat [B.pack "n,,", firstBare])
321+
, "autoAuthorize" =: (1 :: Int)
322+
]
323+
server1 <- lift $ runCommand client1
324+
325+
shortcircuit (true1 "ok" server1) (show server1)
326+
let serverPayload1 = B64.decodeLenient . B.pack . at "payload" $ server1
327+
let serverData1 = parseSCRAM serverPayload1
328+
let iterations = read . B.unpack $ Map.findWithDefault "1" "i" serverData1
329+
let salt = B64.decodeLenient $ Map.findWithDefault "" "s" serverData1
330+
let snonce = Map.findWithDefault "" "r" serverData1
331+
332+
shortcircuit (B.isInfixOf nonce snonce) "nonce"
333+
let withoutProof = B.concat [B.pack "c=biws,r=", snonce]
334+
let digest = case algo of
335+
SHA1 -> B16.encode $ MD5.hash $ B.pack $ T.unpack un ++ ":mongo:" ++ T.unpack pw
336+
SHA256 -> B.pack $ T.unpack $ saslprep pw
337+
let saltedPass = scramHI algo digest salt iterations
338+
let clientKey = hmac saltedPass (B.pack "Client Key")
339+
let storedKey = hash algo clientKey
340+
let authMsg = B.concat [firstBare, B.pack ",", serverPayload1, B.pack ",", withoutProof]
341+
let clientSig = hmac storedKey authMsg
342+
let pval = B64.encode . BS.pack $ BS.zipWith xor clientKey clientSig
343+
let clientFinal = B.concat [withoutProof, B.pack ",p=", pval]
344+
345+
let client2 =
346+
[ "saslContinue" =: (1 :: Int)
347+
, "conversationId" =: (at "conversationId" server1 :: Int)
348+
, "payload" =: B.unpack (B64.encode clientFinal)
349+
]
350+
server2 <- lift $ runCommand client2
351+
shortcircuit (true1 "ok" server2) (show server2)
352+
353+
let serverKey = hmac saltedPass (B.pack "Server Key")
354+
let serverSig = B64.encode $ hmac serverKey authMsg
355+
let serverPayload2 = B64.decodeLenient . B.pack $ at "payload" server2
356+
let serverData2 = parseSCRAM serverPayload2
357+
let serverSigComp = Map.findWithDefault "" "v" serverData2
358+
359+
shortcircuit (serverSig == serverSigComp) "server signature does not match"
360+
if true1 "done" server2
361+
then return ()
362+
else do
363+
let client2Step2 = [ "saslContinue" =: (1 :: Int)
364+
, "conversationId" =: (at "conversationId" server1 :: Int)
365+
, "payload" =: String ""]
366+
server3 <- lift $ runCommand client2Step2
367+
shortcircuit (true1 "ok" server3) "server3"
368+
369+
shortcircuit :: Monad m => Bool -> String -> ExceptT String m ()
370+
shortcircuit True _ = pure ()
371+
shortcircuit False reason = throwE (show reason)
372+
373+
scramHI :: HashAlgorithm -> B.ByteString -> B.ByteString -> Int -> B.ByteString
374+
scramHI algo digest salt iters = snd $ foldl com (u1, u1) [1..(iters-1)]
342375
where
343-
hmacd = HMAC.hmac SHA1.hash 64 digest
376+
hmacd = HMAC.hmac (hash algo) 64 digest
344377
u1 = hmacd (B.concat [salt, BS.pack [0, 0, 0, 1]])
345378
com (u,uc) _ = let u' = hmacd u in (u', BS.pack $ BS.zipWith xor uc u')
346379

0 commit comments

Comments
 (0)