diff --git a/asyncua/client/ua_client.py b/asyncua/client/ua_client.py index a846fe9ec..390d8b5ef 100644 --- a/asyncua/client/ua_client.py +++ b/asyncua/client/ua_client.py @@ -88,7 +88,7 @@ def _process_received_data(self, data: bytes) -> None: return msg = self._connection.receive_from_header_and_body(header, buf) self._process_received_message(msg) - if header.MessageType == ua.MessageType.SecureOpen: + if header.MessageType == ua.MessageType.SecureOpen and isinstance(msg,ua.Message): params: ua.OpenSecureChannelParameters = self._open_secure_channel_exchange response: ua.OpenSecureChannelResponse = struct_from_binary(ua.OpenSecureChannelResponse, msg.body()) response.ResponseHeader.ServiceResult.check() @@ -107,7 +107,7 @@ def _process_received_data(self, data: bytes) -> None: self.disconnect_socket() return - def _process_received_message(self, msg: Union[ua.Message, ua.Acknowledge, ua.ErrorMessage]): + def _process_received_message(self, msg: Union[None,ua.Message, ua.Acknowledge, ua.ErrorMessage]): if msg is None: pass elif isinstance(msg, ua.Message): diff --git a/asyncua/common/connection.py b/asyncua/common/connection.py index 4e972ad4b..3e433c2fe 100644 --- a/asyncua/common/connection.py +++ b/asyncua/common/connection.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import hashlib from datetime import datetime, timedelta, timezone +from typing import Optional, List, TYPE_CHECKING, Union import logging import copy @@ -14,6 +15,10 @@ class InvalidSignature(Exception): # type: ignore pass +if TYPE_CHECKING: + from asyncua.common.utils import Buffer + from asyncua.ua.uaprotocol_hand import SecurityPolicy, SecurityPolicyFactory + _logger = logging.getLogger('asyncua.uaprotocol') @@ -51,7 +56,7 @@ def is_chunk_count_within_limit(self, sz: int) -> bool: _logger.error("Number of message chunks: %s is > configured max chunk count: %s", sz, self.max_chunk_count) return within_limit - def create_acknowledge_and_set_limits(self, msg: ua.Hello) -> ua.Acknowledge: + def create_acknowledge_and_set_limits(self, msg: "ua.Hello") -> "ua.Acknowledge": ack = ua.Acknowledge() ack.ReceiveBufferSize = min(msg.ReceiveBufferSize, self.max_send_buffer) ack.SendBufferSize = min(msg.SendBufferSize, self.max_recv_buffer) @@ -64,14 +69,14 @@ def create_acknowledge_and_set_limits(self, msg: ua.Hello) -> ua.Acknowledge: _logger.info("updating server limits to: %s", self) return ack - def create_hello_limits(self, msg: ua.Hello) -> ua.Hello: + def create_hello_limits(self, msg: "ua.Hello") -> "ua.Hello": msg.ReceiveBufferSize = self.max_recv_buffer msg.SendBufferSize = self.max_send_buffer msg.MaxChunkCount = self.max_chunk_count msg.MaxMessageSize = self.max_chunk_count return msg - def update_client_limits(self, msg: ua.Acknowledge) -> None: + def update_client_limits(self, msg: "ua.Acknowledge") -> None: self.max_chunk_count = msg.MaxChunkCount self.max_recv_buffer = msg.ReceiveBufferSize self.max_send_buffer = msg.SendBufferSize @@ -105,7 +110,7 @@ def from_binary(security_policy, data): return MessageChunk.from_header_and_body(security_policy, h, data, use_prev_key=True) @staticmethod - def from_header_and_body(security_policy, header, buf, use_prev_key=False): + def from_header_and_body(security_policy: "SecurityPolicy", header, buf, use_prev_key=False): if not len(buf) >= header.body_size: raise ValueError('Full body expected here') data = buf.copy(header.body_size) @@ -156,7 +161,7 @@ def max_body_size(crypto, max_chunk_size): return max_plain_size - ua.SequenceHeader.max_size() - crypto.signature_size() - crypto.min_padding_size() @staticmethod - def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1): + def message_to_chunks(security_policy: "SecurityPolicy", body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1): """ Pack message body (as binary string) into one or more chunks. Size of each chunk will not exceed max_chunk_size. @@ -179,7 +184,7 @@ def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.Mes crypto = security_policy.symmetric_cryptography max_size = MessageChunk.max_body_size(crypto, max_chunk_size) - chunks = [] + chunks: List[MessageChunk] = [] for i in range(0, len(body), max_size): part = body[i:i + max_size] if i + max_size >= len(body): @@ -204,22 +209,22 @@ class SecureConnection: """ Common logic for client and server """ - def __init__(self, security_policy, limits: TransportLimits): - self._sequence_number = 0 - self._peer_sequence_number = None - self._incoming_parts = [] - self.security_policy = security_policy - self._policies = [] - self._open = False + def __init__(self, security_policy: "SecurityPolicy", limits: TransportLimits) -> None: + self._sequence_number: int = 0 + self._peer_sequence_number: Optional[int] = None + self._incoming_parts: List[MessageChunk] = [] + self.security_policy: SecurityPolicy = security_policy + self._policies: List[SecurityPolicyFactory] = [] + self._open: bool = False self.security_token = ua.ChannelSecurityToken() self.next_security_token = ua.ChannelSecurityToken() self.prev_security_token = ua.ChannelSecurityToken() - self.local_nonce = 0 - self.remote_nonce = 0 - self._allow_prev_token = False - self._limits = limits + self.local_nonce: int = 0 + self.remote_nonce:int = 0 + self._allow_prev_token: bool = False + self._limits: TransportLimits = limits - def set_channel(self, params, request_type, client_nonce): + def set_channel(self, params, request_type, client_nonce) -> None: """ Called on client side when getting secure channel data from server. """ @@ -241,7 +246,7 @@ def set_channel(self, params, request_type, client_nonce): self._allow_prev_token = True - def open(self, params, server): + def open(self, params, server) -> ua.OpenSecureChannelResult: """ Called on server side to open secure channel. """ @@ -276,13 +281,13 @@ def open(self, params, server): return response - def close(self): + def close(self) -> None: self._open = False - def is_open(self): + def is_open(self) -> bool: return self._open - def set_policy_factories(self, policies): + def set_policy_factories(self, policies: "List[SecurityPolicyFactory]") -> None: """ Set a list of available security policies. Use this in servers with multiple endpoints with different security. @@ -290,10 +295,10 @@ def set_policy_factories(self, policies): self._policies = policies @staticmethod - def _policy_matches(policy, uri, mode=None): + def _policy_matches(policy: "SecurityPolicy", uri, mode=None) -> bool: return policy.URI == uri and (mode is None or policy.Mode == mode) - def select_policy(self, uri, peer_certificate, mode=None): + def select_policy(self, uri: str, peer_certificate, mode=None): for policy in self._policies: if policy.matches(uri, mode): self.security_policy = policy.create(peer_certificate) @@ -301,7 +306,7 @@ def select_policy(self, uri, peer_certificate, mode=None): if self.security_policy.URI != uri or (mode is not None and self.security_policy.Mode != mode): raise ua.UaError(f"No matching policy: {uri}, {mode}") - def revolve_tokens(self): + def revolve_tokens(self) -> None: """ Revolve security tokens of the security channel. Start using the next security token negotiated during the renewal of the channel and @@ -389,7 +394,7 @@ def _check_incoming_chunk(self, chunk): raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:" f" {self._peer_sequence_number}, received: {seq_num}," f" spec says to close connection") self._peer_sequence_number = seq_num - def receive_from_header_and_body(self, header, body): + def receive_from_header_and_body(self, header: ua.Header, body: "Buffer") -> Union[None,ua.Message,"ua.Hello",ua.Acknowledge,ua.ErrorMessage]: """ Convert MessageHeader and binary body to OPC UA TCP message (see OPC UA specs Part 6, 7.1: Hello, Acknowledge or ErrorMessage), or a Message @@ -430,7 +435,7 @@ def receive_from_header_and_body(self, header, body): return msg raise ua.UaError(f"Unsupported message type {header.MessageType}") - def _receive(self, msg): + def _receive(self, msg: MessageChunk) -> Optional[ua.Message]: if msg.MessageHeader.packet_size > self._limits.max_recv_buffer: self._incoming_parts = [] _logger.error("Message size: %s is > chunk max size: %s", msg.MessageHeader.packet_size, self._limits.max_recv_buffer) diff --git a/asyncua/crypto/permission_rules.py b/asyncua/crypto/permission_rules.py index 23e5bc3a7..968f873ac 100644 --- a/asyncua/crypto/permission_rules.py +++ b/asyncua/crypto/permission_rules.py @@ -1,7 +1,15 @@ +from abc import abstractmethod from asyncua import ua from asyncua.server.users import UserRole +from abc import ABC -WRITE_TYPES = [ +from typing import TYPE_CHECKING, Tuple, Dict, Set + +if TYPE_CHECKING: + from asyncua.server.users import User + from asyncua.common.utils import Buffer + +WRITE_TYPES: Tuple[int,...] = ( ua.ObjectIds.WriteRequest_Encoding_DefaultBinary, ua.ObjectIds.RegisterServerRequest_Encoding_DefaultBinary, ua.ObjectIds.RegisterServer2Request_Encoding_DefaultBinary, @@ -11,9 +19,9 @@ ua.ObjectIds.DeleteReferencesRequest_Encoding_DefaultBinary, ua.ObjectIds.RegisterNodesRequest_Encoding_DefaultBinary, ua.ObjectIds.UnregisterNodesRequest_Encoding_DefaultBinary -] +) -READ_TYPES = [ +READ_TYPES: Tuple[int,...] = ( ua.ObjectIds.CreateSessionRequest_Encoding_DefaultBinary, ua.ObjectIds.CloseSessionRequest_Encoding_DefaultBinary, ua.ObjectIds.ActivateSessionRequest_Encoding_DefaultBinary, @@ -33,16 +41,17 @@ ua.ObjectIds.CloseSecureChannelRequest_Encoding_DefaultBinary, ua.ObjectIds.CallRequest_Encoding_DefaultBinary, ua.ObjectIds.SetMonitoringModeRequest_Encoding_DefaultBinary, - ua.ObjectIds.SetPublishingModeRequest_Encoding_DefaultBinary -] + ua.ObjectIds.SetPublishingModeRequest_Encoding_DefaultBinary, +) -class PermissionRuleset: +class PermissionRuleset(ABC): """ Base class for permission ruleset """ - def check_validity(self, user, action_type, body): + @abstractmethod + def check_validity(self, user: "User", action_type_id: ua.NodeId, body: "Buffer") -> bool: raise NotImplementedError @@ -52,16 +61,16 @@ class SimpleRoleRuleset(PermissionRuleset): Admins alone can write, admins and users can read, and anonymous users can't do anything. """ - def __init__(self): + def __init__(self) -> None: write_ids = list(map(ua.NodeId, WRITE_TYPES)) read_ids = list(map(ua.NodeId, READ_TYPES)) - self._permission_dict = { + self._permission_dict: Dict[UserRole, Set[ua.NodeId]] = { UserRole.Admin: set().union(write_ids, read_ids), UserRole.User: set().union(read_ids), UserRole.Anonymous: set() } - def check_validity(self, user, action_type_id, body): + def check_validity(self, user: "User", action_type_id: ua.NodeId, body: "Buffer") -> bool: if action_type_id in self._permission_dict[user.role]: return True else: diff --git a/asyncua/crypto/security_policies.py b/asyncua/crypto/security_policies.py index f83e082dc..e64eb86b8 100644 --- a/asyncua/crypto/security_policies.py +++ b/asyncua/crypto/security_policies.py @@ -49,14 +49,14 @@ class Verifier: __metaclass__ = ABCMeta @abstractmethod - def signature_size(self): + def signature_size(self) -> None: pass @abstractmethod - def verify(self, data, signature): + def verify(self, data, signature) -> None: pass - def reset(self): + def reset(self) -> None: attrs = self.__dict__ for k in attrs: attrs[k] = None @@ -70,11 +70,11 @@ class Encryptor: __metaclass__ = ABCMeta @abstractmethod - def plain_block_size(self): + def plain_block_size(self) -> int: pass @abstractmethod - def encrypted_block_size(self): + def encrypted_block_size(self) -> int: pass @abstractmethod @@ -90,18 +90,18 @@ class Decryptor: __metaclass__ = ABCMeta @abstractmethod - def plain_block_size(self): + def plain_block_size(self) -> int: pass @abstractmethod - def encrypted_block_size(self): + def encrypted_block_size(self) -> int: pass @abstractmethod def decrypt(self, data): pass - def reset(self): + def reset(self) -> None: attrs = self.__dict__ for k in attrs: attrs[k] = None diff --git a/asyncua/crypto/uacrypto.py b/asyncua/crypto/uacrypto.py index 0e1a869bc..e124e4c86 100644 --- a/asyncua/crypto/uacrypto.py +++ b/asyncua/crypto/uacrypto.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone from pathlib import Path import aiofiles -from typing import Optional, Union +from typing import Optional, Sequence, Union, List, Tuple from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -18,6 +18,8 @@ from dataclasses import dataclass import logging + + _logger = logging.getLogger(__name__) @@ -80,7 +82,7 @@ async def load_private_key(path_or_content: Union[str, Path, bytes], return serialization.load_der_private_key(content, password=password, backend=default_backend()) -def der_from_x509(certificate): +def der_from_x509(certificate: x509.Certificate) -> bytes: if certificate is None: return b"" return certificate.public_bytes(serialization.Encoding.DER) @@ -98,7 +100,7 @@ def pem_from_key(private_key: rsa.RSAPrivateKey) -> bytes: return private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption()) -def sign_sha1(private_key, data): +def sign_sha1(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: return private_key.sign( data, padding.PKCS1v15(), @@ -106,7 +108,7 @@ def sign_sha1(private_key, data): ) -def sign_sha256(private_key, data): +def sign_sha256(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: return private_key.sign( data, padding.PKCS1v15(), @@ -114,7 +116,7 @@ def sign_sha256(private_key, data): ) -def sign_pss_sha256(private_key, data): +def sign_pss_sha256(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: return private_key.sign( data, padding.PSS( @@ -125,8 +127,10 @@ def sign_pss_sha256(private_key, data): ) -def verify_sha1(certificate, data, signature): - certificate.public_key().verify( +def verify_sha1(certificate: x509.Certificate, data: bytes, signature: bytes) -> None: + pub_key = certificate.public_key() + assert isinstance(pub_key,rsa.RSAPublicKey) + pub_key.verify( signature, data, padding.PKCS1v15(), @@ -134,16 +138,20 @@ def verify_sha1(certificate, data, signature): ) -def verify_sha256(certificate, data, signature): - certificate.public_key().verify( +def verify_sha256(certificate: x509.Certificate, data: bytes, signature: bytes) -> None: + pub_key = certificate.public_key() + assert isinstance(pub_key,rsa.RSAPublicKey) + pub_key.verify( signature, data, padding.PKCS1v15(), hashes.SHA256()) -def verify_pss_sha256(certificate, data, signature): - certificate.public_key().verify( +def verify_pss_sha256(certificate: x509.Certificate, data: bytes, signature: bytes) -> None: + pub_key = certificate.public_key() + assert isinstance(pub_key,rsa.RSAPublicKey) + pub_key.verify( signature, data, padding.PSS( @@ -154,7 +162,7 @@ def verify_pss_sha256(certificate, data, signature): ) -def encrypt_basic256(public_key, data): +def encrypt_basic256(public_key: rsa.RSAPublicKey, data: bytes) -> bytes: ciphertext = public_key.encrypt( data, padding.OAEP( @@ -165,7 +173,7 @@ def encrypt_basic256(public_key, data): return ciphertext -def encrypt_rsa_oaep(public_key, data): +def encrypt_rsa_oaep(public_key: rsa.RSAPublicKey, data: bytes) -> bytes: ciphertext = public_key.encrypt( data, padding.OAEP( @@ -176,7 +184,7 @@ def encrypt_rsa_oaep(public_key, data): return ciphertext -def encrypt_rsa_oaep_sha256(public_key, data): +def encrypt_rsa_oaep_sha256(public_key: rsa.RSAPublicKey, data: bytes) -> bytes: ciphertext = public_key.encrypt( data, padding.OAEP( @@ -188,7 +196,7 @@ def encrypt_rsa_oaep_sha256(public_key, data): return ciphertext -def encrypt_rsa15(public_key, data): +def encrypt_rsa15(public_key: rsa.RSAPublicKey, data: bytes) -> bytes: ciphertext = public_key.encrypt( data, padding.PKCS1v15() @@ -196,7 +204,7 @@ def encrypt_rsa15(public_key, data): return ciphertext -def decrypt_rsa_oaep(private_key, data): +def decrypt_rsa_oaep(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: text = private_key.decrypt( bytes(data), padding.OAEP( @@ -207,7 +215,7 @@ def decrypt_rsa_oaep(private_key, data): return text -def decrypt_rsa_oaep_sha256(private_key, data): +def decrypt_rsa_oaep_sha256(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: text = private_key.decrypt( bytes(data), padding.OAEP( @@ -219,7 +227,7 @@ def decrypt_rsa_oaep_sha256(private_key, data): return text -def decrypt_rsa15(private_key, data): +def decrypt_rsa15(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: text = private_key.decrypt( bytes(data), padding.PKCS1v15() @@ -227,28 +235,28 @@ def decrypt_rsa15(private_key, data): return text -def cipher_aes_cbc(key, init_vec): +def cipher_aes_cbc(key: bytes, init_vec: bytes) -> Cipher: # FIXME sonarlint reports critical vulnerability (python:S5542) return Cipher(algorithms.AES(key), modes.CBC(init_vec), default_backend()) -def cipher_encrypt(cipher, data): +def cipher_encrypt(cipher: Cipher, data: bytes) -> bytes: encryptor = cipher.encryptor() return encryptor.update(data) + encryptor.finalize() -def cipher_decrypt(cipher, data): +def cipher_decrypt(cipher: Cipher, data: bytes) -> bytes: decryptor = cipher.decryptor() return decryptor.update(data) + decryptor.finalize() -def hmac_sha1(key, message): +def hmac_sha1(key: bytes, message: bytes) -> bytes: hasher = hmac.HMAC(key, hashes.SHA1(), backend=default_backend()) hasher.update(message) return hasher.finalize() -def hmac_sha256(key, message): +def hmac_sha256(key: bytes, message: bytes) -> bytes: hasher = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) hasher.update(message) return hasher.finalize() @@ -262,7 +270,7 @@ def sha256_size(): return hashes.SHA256.digest_size -def p_sha1(secret, seed, sizes=()): +def p_sha1(secret: bytes, seed: bytes, sizes: Sequence[int]=()) -> Tuple[bytes,...]: """ Derive one or more keys from secret and seed. (See specs part 6, 6.7.5 and RFC 2246 - TLS v1.0) @@ -278,14 +286,14 @@ def p_sha1(secret, seed, sizes=()): accum = hmac_sha1(secret, accum) result += hmac_sha1(secret, accum + seed) - parts = [] + parts: List[bytes] = [] for size in sizes: parts.append(result[:size]) result = result[size:] return tuple(parts) -def p_sha256(secret, seed, sizes=()): +def p_sha256(secret: bytes, seed: bytes, sizes: Sequence[int]=()) -> Tuple[bytes,...]: """ Derive one or more keys from secret and seed. (See specs part 6, 6.7.5 and RFC 2246 - TLS v1.0) @@ -301,19 +309,19 @@ def p_sha256(secret, seed, sizes=()): accum = hmac_sha256(secret, accum) result += hmac_sha256(secret, accum + seed) - parts = [] + parts: List[bytes] = [] for size in sizes: parts.append(result[:size]) result = result[size:] return tuple(parts) -def x509_name_to_string(name): - parts = [f"{attr.oid._name}={attr.value}" for attr in name] +def x509_name_to_string(name: x509.Name) -> str: + parts = [f"{attr.oid._name}={str(attr.value)}" for attr in name] return ', '.join(parts) -def x509_to_string(cert): +def x509_to_string(cert: x509.Certificate) -> str: """ Convert x509 certificate to human-readable string """ diff --git a/asyncua/server/uaprocessor.py b/asyncua/server/uaprocessor.py index 29000522f..e08fd2066 100644 --- a/asyncua/server/uaprocessor.py +++ b/asyncua/server/uaprocessor.py @@ -2,7 +2,7 @@ import copy import time import logging -from typing import Deque, Optional, Dict +from typing import Deque, Optional, Dict, TYPE_CHECKING from collections import deque from datetime import datetime, timedelta, timezone @@ -12,6 +12,9 @@ from ..common.connection import SecureConnection, TransportLimits from ..common.utils import ServiceError +if TYPE_CHECKING: + from asyncua.common.utils import Buffer + _logger = logging.getLogger(__name__) @@ -40,13 +43,13 @@ def __init__(self, internal_server: InternalServer, transport, limits: Transport # queue for publish results callbacks (using SubscriptionId) # rely on dict insertion order (therefore can't use set()) self._publish_results_subs: Dict[ua.IntegerId, bool] = {} - self._limits = copy.deepcopy(limits) # Copy limits because they get overriden + self._limits: TransportLimits = copy.deepcopy(limits) # Copy limits because they get overriden self._connection = SecureConnection(ua.SecurityPolicy(), self._limits) self._closing: bool = False self._session_watchdog_task: Optional[asyncio.Task] = None self._watchdog_interval: float = 1.0 - def set_policies(self, policies): + def set_policies(self, policies) -> None: self._connection.set_policy_factories(policies) def send_response(self, requesthandle, seqhdr, response, msgtype=ua.MessageType.SecureMessage): @@ -70,7 +73,7 @@ def open_secure_channel(self, algohdr, seqhdr, body): response.Parameters = channel self.send_response(request.RequestHeader.RequestHandle, seqhdr, response, ua.MessageType.SecureOpen) - def get_publish_request(self, subscription_id: ua.IntegerId): + def get_publish_request(self, subscription_id: ua.IntegerId) -> Optional[PublishRequestData]: while True: if not self._publish_requests: # only store one callback per subscription @@ -132,7 +135,7 @@ async def process(self, header, body): raise ServiceError(ua.StatusCodes.BadTcpMessageTypeInvalid) return True - async def process_message(self, seqhdr, body): + async def process_message(self, seqhdr, body: "Buffer"): """ Process incoming messages. """ diff --git a/asyncua/ua/ua_binary.py b/asyncua/ua/ua_binary.py index 9c3ef8742..1d3b7de8b 100644 --- a/asyncua/ua/ua_binary.py +++ b/asyncua/ua/ua_binary.py @@ -690,14 +690,14 @@ def decode(data): return decode -def struct_from_binary(objtype: Union[Type[T], str], data: IO) -> T: +def struct_from_binary(objtype: Union[Type[T], str], data: Union[IO,Buffer]) -> T: """ unpack an ua struct. Arguments are an objtype as Python dataclass or string """ return _create_dataclass_deserializer(objtype)(data) -def header_to_binary(hdr): +def header_to_binary(hdr) -> bytes: b = [struct.pack("<3ss", hdr.MessageType, hdr.ChunkType)] size = hdr.body_size + 8 if hdr.MessageType in (ua.MessageType.SecureOpen, ua.MessageType.SecureClose, ua.MessageType.SecureMessage): @@ -708,7 +708,7 @@ def header_to_binary(hdr): return b"".join(b) -def header_from_binary(data): +def header_from_binary(data) -> "ua.Header": hdr = ua.Header() hdr.MessageType, hdr.ChunkType, hdr.packet_size = struct.unpack("<3scI", data.read(8)) hdr.body_size = hdr.packet_size - 8 @@ -719,7 +719,7 @@ def header_from_binary(data): return hdr -def uatcp_to_binary(message_type, message): +def uatcp_to_binary(message_type, message) -> bytes: """ Convert OPC UA TCP message (see OPC UA specs Part 6, 7.1) to binary. The only supported types are Hello, Acknowledge and ErrorMessage diff --git a/asyncua/ua/uaprotocol_hand.py b/asyncua/ua/uaprotocol_hand.py index d42347eb6..6d6065830 100644 --- a/asyncua/ua/uaprotocol_hand.py +++ b/asyncua/ua/uaprotocol_hand.py @@ -1,11 +1,16 @@ import struct from dataclasses import dataclass, field -from typing import List +from typing import List, TYPE_CHECKING, Optional, Union from asyncua.ua import uaprotocol_auto as auto from asyncua.ua import uatypes from asyncua.common import utils +if TYPE_CHECKING: + from asyncua.common.connection import MessageChunk + from cryptography import x509 + from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes + OPC_TCP_SCHEME = 'opc.tcp' @@ -195,17 +200,23 @@ class SecurityPolicyFactory: Server has one certificate and private key, but needs a separate SecurityPolicy for every client and client's certificate """ - def __init__(self, cls=SecurityPolicy, mode=auto.MessageSecurityMode.None_, certificate=None, private_key=None, permission_ruleset=None): + def __init__(self, + cls=SecurityPolicy, + mode=auto.MessageSecurityMode.None_, + certificate: "Optional[x509.Certificate]"=None, + private_key: "Optional[PrivateKeyTypes]"=None, + permission_ruleset=None + ) -> None: self.cls = cls - self.mode = mode - self.certificate = certificate - self.private_key = private_key + self.mode: auto.MessageSecurityMode = mode + self.certificate: Optional[x509.Certificate] = certificate + self.private_key: Optional[PrivateKeyTypes] = private_key self.permission_ruleset = permission_ruleset - def matches(self, uri, mode=None): + def matches(self, uri: str, mode=None) -> bool: return self.cls.URI == uri and (mode is None or self.mode == mode) - def create(self, peer_certificate): + def create(self, peer_certificate) -> SecurityPolicy: if self.cls is SecurityPolicy: return self.cls(permissions=self.permission_ruleset) else: @@ -213,19 +224,19 @@ def create(self, peer_certificate): class Message: - def __init__(self, chunks): - self._chunks = chunks + def __init__(self, chunks: "List[MessageChunk]") -> None: + self._chunks: List[MessageChunk] = chunks - def request_id(self): + def request_id(self) -> auto.UInt32: return self._chunks[0].SequenceHeader.RequestId - def SequenceHeader(self): + def SequenceHeader(self) -> SequenceHeader: return self._chunks[0].SequenceHeader - def SecurityHeader(self): + def SecurityHeader(self) -> Union[SymmetricAlgorithmHeader, AsymmetricAlgorithmHeader]: return self._chunks[0].SecurityHeader - def body(self): + def body(self) -> utils.Buffer: body = b"".join([c.Body for c in self._chunks]) return utils.Buffer(body)