From 2f683d66c8c276a4acd071c1dfaafd7cc1b0f41e Mon Sep 17 00:00:00 2001 From: Paolo Capriotti Date: Wed, 22 Mar 2023 13:58:54 +0100 Subject: [PATCH] Fix KeyPackage parser --- libs/wire-api/src/Wire/API/MLS/KeyPackage.hs | 10 ++-------- libs/wire-api/src/Wire/API/MLS/ProtocolVersion.hs | 2 +- libs/wire-api/src/Wire/API/MLS/Serialisation.hs | 12 +++++------- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs b/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs index 019790338fc..ff9b74f83af 100644 --- a/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs +++ b/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs @@ -35,7 +35,6 @@ import Cassandra.CQL hiding (Set) import Control.Applicative import Control.Lens hiding (set, (.=)) import Data.Aeson (FromJSON, ToJSON) -import Data.Binary import Data.Binary.Get import Data.Binary.Put import qualified Data.ByteString as B @@ -167,7 +166,6 @@ data KeyPackageTBS = KeyPackageTBS cipherSuite :: CipherSuite, initKey :: HPKEPublicKey, leafNode :: LeafNode, - credential :: Credential, extensions :: [Extension] } deriving stock (Eq, Show, Generic) @@ -180,7 +178,6 @@ instance ParseMLS KeyPackageTBS where <*> parseMLS <*> parseMLS <*> parseMLS - <*> parseMLS <*> parseMLSVector @VarInt parseMLS data KeyPackage = KeyPackage @@ -201,9 +198,6 @@ instance HasField "cipherSuite" KeyPackage CipherSuite where instance HasField "initKey" KeyPackage HPKEPublicKey where getField = (.tbs.rmValue.initKey) -instance HasField "credential" KeyPackage Credential where - getField = (.tbs.rmValue.credential) - instance HasField "extensions" KeyPackage [Extension] where getField = (.tbs.rmValue.extensions) @@ -211,7 +205,7 @@ instance HasField "leafNode" KeyPackage LeafNode where getField = (.tbs.rmValue.leafNode) keyPackageIdentity :: KeyPackage -> Either Text ClientIdentity -keyPackageIdentity = decodeMLS' @ClientIdentity . (.credential.identityData) +keyPackageIdentity = decodeMLS' @ClientIdentity . (.leafNode.credential.identityData) rawKeyPackageSchema :: ValueSchema NamedSwaggerDoc (RawMLS KeyPackage) rawKeyPackageSchema = @@ -225,7 +219,7 @@ instance ParseMLS KeyPackage where parseMLS = KeyPackage <$> parseRawMLS parseMLS - <*> parseMLSBytes @Word16 + <*> parseMLSBytes @VarInt -------------------------------------------------------------------------------- diff --git a/libs/wire-api/src/Wire/API/MLS/ProtocolVersion.hs b/libs/wire-api/src/Wire/API/MLS/ProtocolVersion.hs index c20bbe153b7..9fcbb718470 100644 --- a/libs/wire-api/src/Wire/API/MLS/ProtocolVersion.hs +++ b/libs/wire-api/src/Wire/API/MLS/ProtocolVersion.hs @@ -30,7 +30,7 @@ import Imports import Wire.API.MLS.Serialisation import Wire.Arbitrary -newtype ProtocolVersion = ProtocolVersion {pvNumber :: Word8} +newtype ProtocolVersion = ProtocolVersion {pvNumber :: Word16} deriving newtype (Eq, Ord, Show, Binary, Arbitrary, ParseMLS, SerialiseMLS) data ProtocolVersionTag = ProtocolMLS10 | ProtocolMLSDraft11 diff --git a/libs/wire-api/src/Wire/API/MLS/Serialisation.hs b/libs/wire-api/src/Wire/API/MLS/Serialisation.hs index 04472c0dbef..d241bf1ff2a 100644 --- a/libs/wire-api/src/Wire/API/MLS/Serialisation.hs +++ b/libs/wire-api/src/Wire/API/MLS/Serialisation.hs @@ -108,13 +108,11 @@ instance Binary VarInt where get :: Get VarInt get = do w <- lookAhead getWord8 - let x = shiftR (w .&. 0xc0) 6 - maskVarInt = VarInt . (.&. 0x3fffffff) - if - | x == 0b00 -> maskVarInt . fromIntegral <$> getWord8 - | x == 0b01 -> maskVarInt . fromIntegral <$> getWord16be - | x == 0b10 -> maskVarInt . fromIntegral <$> getWord32be - | otherwise -> fail "invalid VarInt prefix" + case shiftR (w .&. 0xc0) 6 of + 0b00 -> VarInt . fromIntegral <$> getWord8 + 0b01 -> VarInt . (.&. 0x3fff) . fromIntegral <$> getWord16be + 0b10 -> VarInt . (.&. 0x3fffffff) . fromIntegral <$> getWord32be + _ -> fail "invalid VarInt prefix" instance SerialiseMLS VarInt where serialiseMLS = put