From 7de3e1de36157fa4506933b85fb8393f8359eeeb Mon Sep 17 00:00:00 2001 From: Jonathan Becker Date: Wed, 2 Oct 2024 13:15:57 +0200 Subject: [PATCH 1/5] add type hints around crypto module --- asyncua/client/ha/reconciliator.py | 3 +- asyncua/client/ua_client.py | 4 +- asyncua/common/connection.py | 53 ++++++++++++---------- asyncua/crypto/permission_rules.py | 29 +++++++----- asyncua/crypto/security_policies.py | 16 +++---- asyncua/crypto/uacrypto.py | 70 ++++++++++++++++------------- asyncua/server/uaprocessor.py | 13 +++--- asyncua/ua/ua_binary.py | 13 +++--- asyncua/ua/uaprotocol_hand.py | 38 ++++++++++------ 9 files changed, 141 insertions(+), 98 deletions(-) diff --git a/asyncua/client/ha/reconciliator.py b/asyncua/client/ha/reconciliator.py index e7055f9c2..1d267c1a5 100644 --- a/asyncua/client/ha/reconciliator.py +++ b/asyncua/client/ha/reconciliator.py @@ -150,7 +150,8 @@ async def reconciliate(self) -> None: for url in valid_urls: digest_ideal = get_digest(ideal_map[url]) digest_real = get_digest(real_map[url]) - if url not in real_map or digest_ideal != digest_real: + #if url not in real_map or digest_ideal != digest_real: + if url not in real_map or ideal_map[url] != real_map[url]: targets.add(url) if not targets: _logger.info( diff --git a/asyncua/client/ua_client.py b/asyncua/client/ua_client.py index a846fe9ec..390d8b5ef 100644 --- a/asyncua/client/ua_client.py +++ b/asyncua/client/ua_client.py @@ -88,7 +88,7 @@ def _process_received_data(self, data: bytes) -> None: return msg = self._connection.receive_from_header_and_body(header, buf) self._process_received_message(msg) - if header.MessageType == ua.MessageType.SecureOpen: + if header.MessageType == ua.MessageType.SecureOpen and isinstance(msg,ua.Message): params: ua.OpenSecureChannelParameters = self._open_secure_channel_exchange response: ua.OpenSecureChannelResponse = struct_from_binary(ua.OpenSecureChannelResponse, msg.body()) response.ResponseHeader.ServiceResult.check() @@ -107,7 +107,7 @@ def _process_received_data(self, data: bytes) -> None: self.disconnect_socket() return - def _process_received_message(self, msg: Union[ua.Message, ua.Acknowledge, ua.ErrorMessage]): + def _process_received_message(self, msg: Union[None,ua.Message, ua.Acknowledge, ua.ErrorMessage]): if msg is None: pass elif isinstance(msg, ua.Message): diff --git a/asyncua/common/connection.py b/asyncua/common/connection.py index 4e972ad4b..6ebb17c49 100644 --- a/asyncua/common/connection.py +++ b/asyncua/common/connection.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import hashlib from datetime import datetime, timedelta, timezone +from typing import Optional, List, TYPE_CHECKING, Union import logging import copy @@ -14,6 +15,10 @@ class InvalidSignature(Exception): # type: ignore pass +if TYPE_CHECKING: + from asyncua.common.utils import Buffer + from asyncua.ua.uaprotocol_hand import SecurityPolicy, SecurityPolicyFactory + _logger = logging.getLogger('asyncua.uaprotocol') @@ -105,7 +110,7 @@ def from_binary(security_policy, data): return MessageChunk.from_header_and_body(security_policy, h, data, use_prev_key=True) @staticmethod - def from_header_and_body(security_policy, header, buf, use_prev_key=False): + def from_header_and_body(security_policy: "SecurityPolicy", header, buf, use_prev_key=False): if not len(buf) >= header.body_size: raise ValueError('Full body expected here') data = buf.copy(header.body_size) @@ -156,7 +161,7 @@ def max_body_size(crypto, max_chunk_size): return max_plain_size - ua.SequenceHeader.max_size() - crypto.signature_size() - crypto.min_padding_size() @staticmethod - def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1): + def message_to_chunks(security_policy: "SecurityPolicy", body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1): """ Pack message body (as binary string) into one or more chunks. Size of each chunk will not exceed max_chunk_size. @@ -179,7 +184,7 @@ def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.Mes crypto = security_policy.symmetric_cryptography max_size = MessageChunk.max_body_size(crypto, max_chunk_size) - chunks = [] + chunks: List[MessageChunk] = [] for i in range(0, len(body), max_size): part = body[i:i + max_size] if i + max_size >= len(body): @@ -204,22 +209,22 @@ class SecureConnection: """ Common logic for client and server """ - def __init__(self, security_policy, limits: TransportLimits): - self._sequence_number = 0 - self._peer_sequence_number = None - self._incoming_parts = [] - self.security_policy = security_policy - self._policies = [] - self._open = False + def __init__(self, security_policy: "SecurityPolicy", limits: TransportLimits) -> None: + self._sequence_number: int = 0 + self._peer_sequence_number: Optional[int] = None + self._incoming_parts: List[MessageChunk] = [] + self.security_policy: "SecurityPolicy" = security_policy + self._policies: "List[SecurityPolicyFactory]" = [] + self._open: bool = False self.security_token = ua.ChannelSecurityToken() self.next_security_token = ua.ChannelSecurityToken() self.prev_security_token = ua.ChannelSecurityToken() - self.local_nonce = 0 - self.remote_nonce = 0 - self._allow_prev_token = False - self._limits = limits + self.local_nonce: int = 0 + self.remote_nonce:int = 0 + self._allow_prev_token: bool = False + self._limits: TransportLimits = limits - def set_channel(self, params, request_type, client_nonce): + def set_channel(self, params, request_type, client_nonce) -> None: """ Called on client side when getting secure channel data from server. """ @@ -241,7 +246,7 @@ def set_channel(self, params, request_type, client_nonce): self._allow_prev_token = True - def open(self, params, server): + def open(self, params, server) -> ua.OpenSecureChannelResult: """ Called on server side to open secure channel. """ @@ -276,13 +281,13 @@ def open(self, params, server): return response - def close(self): + def close(self) -> None: self._open = False - def is_open(self): + def is_open(self) -> bool: return self._open - def set_policy_factories(self, policies): + def set_policy_factories(self, policies: "List[SecurityPolicyFactory]") -> None: """ Set a list of available security policies. Use this in servers with multiple endpoints with different security. @@ -290,10 +295,10 @@ def set_policy_factories(self, policies): self._policies = policies @staticmethod - def _policy_matches(policy, uri, mode=None): + def _policy_matches(policy: "SecurityPolicy", uri, mode=None) -> bool: return policy.URI == uri and (mode is None or policy.Mode == mode) - def select_policy(self, uri, peer_certificate, mode=None): + def select_policy(self, uri: str, peer_certificate, mode=None): for policy in self._policies: if policy.matches(uri, mode): self.security_policy = policy.create(peer_certificate) @@ -301,7 +306,7 @@ def select_policy(self, uri, peer_certificate, mode=None): if self.security_policy.URI != uri or (mode is not None and self.security_policy.Mode != mode): raise ua.UaError(f"No matching policy: {uri}, {mode}") - def revolve_tokens(self): + def revolve_tokens(self) -> None: """ Revolve security tokens of the security channel. Start using the next security token negotiated during the renewal of the channel and @@ -389,7 +394,7 @@ def _check_incoming_chunk(self, chunk): raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:" f" {self._peer_sequence_number}, received: {seq_num}," f" spec says to close connection") self._peer_sequence_number = seq_num - def receive_from_header_and_body(self, header, body): + def receive_from_header_and_body(self, header: ua.Header, body: "Buffer") -> Union[None,ua.Message,ua.Hello,ua.Acknowledge,ua.ErrorMessage]: """ Convert MessageHeader and binary body to OPC UA TCP message (see OPC UA specs Part 6, 7.1: Hello, Acknowledge or ErrorMessage), or a Message @@ -430,7 +435,7 @@ def receive_from_header_and_body(self, header, body): return msg raise ua.UaError(f"Unsupported message type {header.MessageType}") - def _receive(self, msg): + def _receive(self, msg: MessageChunk) -> Optional[ua.Message]: if msg.MessageHeader.packet_size > self._limits.max_recv_buffer: self._incoming_parts = [] _logger.error("Message size: %s is > chunk max size: %s", msg.MessageHeader.packet_size, self._limits.max_recv_buffer) diff --git a/asyncua/crypto/permission_rules.py b/asyncua/crypto/permission_rules.py index 23e5bc3a7..968f873ac 100644 --- a/asyncua/crypto/permission_rules.py +++ b/asyncua/crypto/permission_rules.py @@ -1,7 +1,15 @@ +from abc import abstractmethod from asyncua import ua from asyncua.server.users import UserRole +from abc import ABC -WRITE_TYPES = [ +from typing import TYPE_CHECKING, Tuple, Dict, Set + +if TYPE_CHECKING: + from asyncua.server.users import User + from asyncua.common.utils import Buffer + +WRITE_TYPES: Tuple[int,...] = ( ua.ObjectIds.WriteRequest_Encoding_DefaultBinary, ua.ObjectIds.RegisterServerRequest_Encoding_DefaultBinary, ua.ObjectIds.RegisterServer2Request_Encoding_DefaultBinary, @@ -11,9 +19,9 @@ ua.ObjectIds.DeleteReferencesRequest_Encoding_DefaultBinary, ua.ObjectIds.RegisterNodesRequest_Encoding_DefaultBinary, ua.ObjectIds.UnregisterNodesRequest_Encoding_DefaultBinary -] +) -READ_TYPES = [ +READ_TYPES: Tuple[int,...] = ( ua.ObjectIds.CreateSessionRequest_Encoding_DefaultBinary, ua.ObjectIds.CloseSessionRequest_Encoding_DefaultBinary, ua.ObjectIds.ActivateSessionRequest_Encoding_DefaultBinary, @@ -33,16 +41,17 @@ ua.ObjectIds.CloseSecureChannelRequest_Encoding_DefaultBinary, ua.ObjectIds.CallRequest_Encoding_DefaultBinary, ua.ObjectIds.SetMonitoringModeRequest_Encoding_DefaultBinary, - ua.ObjectIds.SetPublishingModeRequest_Encoding_DefaultBinary -] + ua.ObjectIds.SetPublishingModeRequest_Encoding_DefaultBinary, +) -class PermissionRuleset: +class PermissionRuleset(ABC): """ Base class for permission ruleset """ - def check_validity(self, user, action_type, body): + @abstractmethod + def check_validity(self, user: "User", action_type_id: ua.NodeId, body: "Buffer") -> bool: raise NotImplementedError @@ -52,16 +61,16 @@ class SimpleRoleRuleset(PermissionRuleset): Admins alone can write, admins and users can read, and anonymous users can't do anything. """ - def __init__(self): + def __init__(self) -> None: write_ids = list(map(ua.NodeId, WRITE_TYPES)) read_ids = list(map(ua.NodeId, READ_TYPES)) - self._permission_dict = { + self._permission_dict: Dict[UserRole, Set[ua.NodeId]] = { UserRole.Admin: set().union(write_ids, read_ids), UserRole.User: set().union(read_ids), UserRole.Anonymous: set() } - def check_validity(self, user, action_type_id, body): + def check_validity(self, user: "User", action_type_id: ua.NodeId, body: "Buffer") -> bool: if action_type_id in self._permission_dict[user.role]: return True else: diff --git a/asyncua/crypto/security_policies.py b/asyncua/crypto/security_policies.py index f83e082dc..e64eb86b8 100644 --- a/asyncua/crypto/security_policies.py +++ b/asyncua/crypto/security_policies.py @@ -49,14 +49,14 @@ class Verifier: __metaclass__ = ABCMeta @abstractmethod - def signature_size(self): + def signature_size(self) -> None: pass @abstractmethod - def verify(self, data, signature): + def verify(self, data, signature) -> None: pass - def reset(self): + def reset(self) -> None: attrs = self.__dict__ for k in attrs: attrs[k] = None @@ -70,11 +70,11 @@ class Encryptor: __metaclass__ = ABCMeta @abstractmethod - def plain_block_size(self): + def plain_block_size(self) -> int: pass @abstractmethod - def encrypted_block_size(self): + def encrypted_block_size(self) -> int: pass @abstractmethod @@ -90,18 +90,18 @@ class Decryptor: __metaclass__ = ABCMeta @abstractmethod - def plain_block_size(self): + def plain_block_size(self) -> int: pass @abstractmethod - def encrypted_block_size(self): + def encrypted_block_size(self) -> int: pass @abstractmethod def decrypt(self, data): pass - def reset(self): + def reset(self) -> None: attrs = self.__dict__ for k in attrs: attrs[k] = None diff --git a/asyncua/crypto/uacrypto.py b/asyncua/crypto/uacrypto.py index 0e1a869bc..ae9dcf1b0 100644 --- a/asyncua/crypto/uacrypto.py +++ b/asyncua/crypto/uacrypto.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone from pathlib import Path import aiofiles -from typing import Optional, Union +from typing import Optional, Sequence, Union, List, Tuple from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -12,12 +12,16 @@ from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.ciphers import algorithms from cryptography.hazmat.primitives.ciphers import modes +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PublicKey +from cryptography.hazmat.primitives.asymmetric.x448 import X448PublicKey # We redefine InvalidSignature as part of this module. Do not remove this line. from cryptography.exceptions import InvalidSignature # noqa: F811 from dataclasses import dataclass import logging + + _logger = logging.getLogger(__name__) @@ -80,7 +84,7 @@ async def load_private_key(path_or_content: Union[str, Path, bytes], return serialization.load_der_private_key(content, password=password, backend=default_backend()) -def der_from_x509(certificate): +def der_from_x509(certificate: x509.Certificate) -> bytes: if certificate is None: return b"" return certificate.public_bytes(serialization.Encoding.DER) @@ -98,7 +102,7 @@ def pem_from_key(private_key: rsa.RSAPrivateKey) -> bytes: return private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption()) -def sign_sha1(private_key, data): +def sign_sha1(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: return private_key.sign( data, padding.PKCS1v15(), @@ -106,7 +110,7 @@ def sign_sha1(private_key, data): ) -def sign_sha256(private_key, data): +def sign_sha256(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: return private_key.sign( data, padding.PKCS1v15(), @@ -114,7 +118,7 @@ def sign_sha256(private_key, data): ) -def sign_pss_sha256(private_key, data): +def sign_pss_sha256(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: return private_key.sign( data, padding.PSS( @@ -125,8 +129,10 @@ def sign_pss_sha256(private_key, data): ) -def verify_sha1(certificate, data, signature): - certificate.public_key().verify( +def verify_sha1(certificate: x509.Certificate, data: bytes, signature: bytes) -> None: + pub_key = certificate.public_key() + assert isinstance(pub_key,rsa.RSAPublicKey) + pub_key.verify( signature, data, padding.PKCS1v15(), @@ -134,16 +140,20 @@ def verify_sha1(certificate, data, signature): ) -def verify_sha256(certificate, data, signature): - certificate.public_key().verify( +def verify_sha256(certificate: x509.Certificate, data: bytes, signature: bytes) -> None: + pub_key = certificate.public_key() + assert isinstance(pub_key,rsa.RSAPublicKey) + pub_key.verify( signature, data, padding.PKCS1v15(), hashes.SHA256()) -def verify_pss_sha256(certificate, data, signature): - certificate.public_key().verify( +def verify_pss_sha256(certificate: x509.Certificate, data: bytes, signature: bytes) -> None: + pub_key = certificate.public_key() + assert isinstance(pub_key,rsa.RSAPublicKey) + pub_key.verify( signature, data, padding.PSS( @@ -154,7 +164,7 @@ def verify_pss_sha256(certificate, data, signature): ) -def encrypt_basic256(public_key, data): +def encrypt_basic256(public_key: rsa.RSAPublicKey, data: bytes) -> bytes: ciphertext = public_key.encrypt( data, padding.OAEP( @@ -165,7 +175,7 @@ def encrypt_basic256(public_key, data): return ciphertext -def encrypt_rsa_oaep(public_key, data): +def encrypt_rsa_oaep(public_key: rsa.RSAPublicKey, data: bytes) -> bytes: ciphertext = public_key.encrypt( data, padding.OAEP( @@ -176,7 +186,7 @@ def encrypt_rsa_oaep(public_key, data): return ciphertext -def encrypt_rsa_oaep_sha256(public_key, data): +def encrypt_rsa_oaep_sha256(public_key: rsa.RSAPublicKey, data: bytes) -> bytes: ciphertext = public_key.encrypt( data, padding.OAEP( @@ -188,7 +198,7 @@ def encrypt_rsa_oaep_sha256(public_key, data): return ciphertext -def encrypt_rsa15(public_key, data): +def encrypt_rsa15(public_key: rsa.RSAPublicKey, data: bytes) -> bytes: ciphertext = public_key.encrypt( data, padding.PKCS1v15() @@ -196,7 +206,7 @@ def encrypt_rsa15(public_key, data): return ciphertext -def decrypt_rsa_oaep(private_key, data): +def decrypt_rsa_oaep(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: text = private_key.decrypt( bytes(data), padding.OAEP( @@ -207,7 +217,7 @@ def decrypt_rsa_oaep(private_key, data): return text -def decrypt_rsa_oaep_sha256(private_key, data): +def decrypt_rsa_oaep_sha256(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: text = private_key.decrypt( bytes(data), padding.OAEP( @@ -219,7 +229,7 @@ def decrypt_rsa_oaep_sha256(private_key, data): return text -def decrypt_rsa15(private_key, data): +def decrypt_rsa15(private_key: rsa.RSAPrivateKey, data: bytes) -> bytes: text = private_key.decrypt( bytes(data), padding.PKCS1v15() @@ -227,28 +237,28 @@ def decrypt_rsa15(private_key, data): return text -def cipher_aes_cbc(key, init_vec): +def cipher_aes_cbc(key: bytes, init_vec: bytes) -> Cipher: # FIXME sonarlint reports critical vulnerability (python:S5542) return Cipher(algorithms.AES(key), modes.CBC(init_vec), default_backend()) -def cipher_encrypt(cipher, data): +def cipher_encrypt(cipher: Cipher, data: bytes) -> bytes: encryptor = cipher.encryptor() return encryptor.update(data) + encryptor.finalize() -def cipher_decrypt(cipher, data): +def cipher_decrypt(cipher: Cipher, data: bytes) -> bytes: decryptor = cipher.decryptor() return decryptor.update(data) + decryptor.finalize() -def hmac_sha1(key, message): +def hmac_sha1(key: bytes, message: bytes) -> bytes: hasher = hmac.HMAC(key, hashes.SHA1(), backend=default_backend()) hasher.update(message) return hasher.finalize() -def hmac_sha256(key, message): +def hmac_sha256(key: bytes, message: bytes) -> bytes: hasher = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) hasher.update(message) return hasher.finalize() @@ -262,7 +272,7 @@ def sha256_size(): return hashes.SHA256.digest_size -def p_sha1(secret, seed, sizes=()): +def p_sha1(secret: bytes, seed: bytes, sizes: Sequence[int]=()) -> Tuple[bytes,...]: """ Derive one or more keys from secret and seed. (See specs part 6, 6.7.5 and RFC 2246 - TLS v1.0) @@ -278,14 +288,14 @@ def p_sha1(secret, seed, sizes=()): accum = hmac_sha1(secret, accum) result += hmac_sha1(secret, accum + seed) - parts = [] + parts: List[bytes] = [] for size in sizes: parts.append(result[:size]) result = result[size:] return tuple(parts) -def p_sha256(secret, seed, sizes=()): +def p_sha256(secret: bytes, seed: bytes, sizes: Sequence[int]=()) -> Tuple[bytes,...]: """ Derive one or more keys from secret and seed. (See specs part 6, 6.7.5 and RFC 2246 - TLS v1.0) @@ -301,19 +311,19 @@ def p_sha256(secret, seed, sizes=()): accum = hmac_sha256(secret, accum) result += hmac_sha256(secret, accum + seed) - parts = [] + parts: List[bytes] = [] for size in sizes: parts.append(result[:size]) result = result[size:] return tuple(parts) -def x509_name_to_string(name): - parts = [f"{attr.oid._name}={attr.value}" for attr in name] +def x509_name_to_string(name: x509.Name) -> str: + parts = [f"{attr.oid._name}={str(attr.value)}" for attr in name] return ', '.join(parts) -def x509_to_string(cert): +def x509_to_string(cert: x509.Certificate) -> str: """ Convert x509 certificate to human-readable string """ diff --git a/asyncua/server/uaprocessor.py b/asyncua/server/uaprocessor.py index 29000522f..e08fd2066 100644 --- a/asyncua/server/uaprocessor.py +++ b/asyncua/server/uaprocessor.py @@ -2,7 +2,7 @@ import copy import time import logging -from typing import Deque, Optional, Dict +from typing import Deque, Optional, Dict, TYPE_CHECKING from collections import deque from datetime import datetime, timedelta, timezone @@ -12,6 +12,9 @@ from ..common.connection import SecureConnection, TransportLimits from ..common.utils import ServiceError +if TYPE_CHECKING: + from asyncua.common.utils import Buffer + _logger = logging.getLogger(__name__) @@ -40,13 +43,13 @@ def __init__(self, internal_server: InternalServer, transport, limits: Transport # queue for publish results callbacks (using SubscriptionId) # rely on dict insertion order (therefore can't use set()) self._publish_results_subs: Dict[ua.IntegerId, bool] = {} - self._limits = copy.deepcopy(limits) # Copy limits because they get overriden + self._limits: TransportLimits = copy.deepcopy(limits) # Copy limits because they get overriden self._connection = SecureConnection(ua.SecurityPolicy(), self._limits) self._closing: bool = False self._session_watchdog_task: Optional[asyncio.Task] = None self._watchdog_interval: float = 1.0 - def set_policies(self, policies): + def set_policies(self, policies) -> None: self._connection.set_policy_factories(policies) def send_response(self, requesthandle, seqhdr, response, msgtype=ua.MessageType.SecureMessage): @@ -70,7 +73,7 @@ def open_secure_channel(self, algohdr, seqhdr, body): response.Parameters = channel self.send_response(request.RequestHeader.RequestHandle, seqhdr, response, ua.MessageType.SecureOpen) - def get_publish_request(self, subscription_id: ua.IntegerId): + def get_publish_request(self, subscription_id: ua.IntegerId) -> Optional[PublishRequestData]: while True: if not self._publish_requests: # only store one callback per subscription @@ -132,7 +135,7 @@ async def process(self, header, body): raise ServiceError(ua.StatusCodes.BadTcpMessageTypeInvalid) return True - async def process_message(self, seqhdr, body): + async def process_message(self, seqhdr, body: "Buffer"): """ Process incoming messages. """ diff --git a/asyncua/ua/ua_binary.py b/asyncua/ua/ua_binary.py index 9c3ef8742..ed19147fa 100644 --- a/asyncua/ua/ua_binary.py +++ b/asyncua/ua/ua_binary.py @@ -6,7 +6,7 @@ import struct import logging from io import BytesIO -from typing import IO, Any, Callable, Optional, Sequence, Type, TypeVar, Union +from typing import IO, Any, Callable, Optional, Sequence, Type, TypeVar, Union, TYPE_CHECKING import typing import uuid from enum import Enum, IntFlag @@ -16,6 +16,9 @@ from ..common.utils import Buffer from .uatypes import type_from_optional, type_is_list, type_is_union, type_from_list, types_or_list_from_union, type_allow_subclass +if TYPE_CHECKING: + from asyncua.common.utils import Buffer + _logger = logging.getLogger(__name__) T = TypeVar('T') @@ -690,14 +693,14 @@ def decode(data): return decode -def struct_from_binary(objtype: Union[Type[T], str], data: IO) -> T: +def struct_from_binary(objtype: Union[Type[T], str], data: Union[IO,Buffer]) -> T: """ unpack an ua struct. Arguments are an objtype as Python dataclass or string """ return _create_dataclass_deserializer(objtype)(data) -def header_to_binary(hdr): +def header_to_binary(hdr) -> bytes: b = [struct.pack("<3ss", hdr.MessageType, hdr.ChunkType)] size = hdr.body_size + 8 if hdr.MessageType in (ua.MessageType.SecureOpen, ua.MessageType.SecureClose, ua.MessageType.SecureMessage): @@ -708,7 +711,7 @@ def header_to_binary(hdr): return b"".join(b) -def header_from_binary(data): +def header_from_binary(data) -> ua.Header: hdr = ua.Header() hdr.MessageType, hdr.ChunkType, hdr.packet_size = struct.unpack("<3scI", data.read(8)) hdr.body_size = hdr.packet_size - 8 @@ -719,7 +722,7 @@ def header_from_binary(data): return hdr -def uatcp_to_binary(message_type, message): +def uatcp_to_binary(message_type, message) -> bytes: """ Convert OPC UA TCP message (see OPC UA specs Part 6, 7.1) to binary. The only supported types are Hello, Acknowledge and ErrorMessage diff --git a/asyncua/ua/uaprotocol_hand.py b/asyncua/ua/uaprotocol_hand.py index d42347eb6..db68f4647 100644 --- a/asyncua/ua/uaprotocol_hand.py +++ b/asyncua/ua/uaprotocol_hand.py @@ -1,11 +1,17 @@ import struct from dataclasses import dataclass, field -from typing import List +from typing import List, TYPE_CHECKING, Optional +from asyncua.common.connection import MessageChunk from asyncua.ua import uaprotocol_auto as auto from asyncua.ua import uatypes from asyncua.common import utils +if TYPE_CHECKING: + from asyncua.common.connection import MessageChunk + from cryptography import x509 + from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes + OPC_TCP_SCHEME = 'opc.tcp' @@ -195,17 +201,23 @@ class SecurityPolicyFactory: Server has one certificate and private key, but needs a separate SecurityPolicy for every client and client's certificate """ - def __init__(self, cls=SecurityPolicy, mode=auto.MessageSecurityMode.None_, certificate=None, private_key=None, permission_ruleset=None): + def __init__(self, + cls=SecurityPolicy, + mode=auto.MessageSecurityMode.None_, + certificate: "Optional[x509.Certificate]"=None, + private_key: "Optional[PrivateKeyTypes]"=None, + permission_ruleset=None + ) -> None: self.cls = cls - self.mode = mode - self.certificate = certificate - self.private_key = private_key + self.mode: auto.MessageSecurityMode = mode + self.certificate: "Optional[x509.Certificate]" = certificate + self.private_key: "Optional[PrivateKeyTypes]" = private_key self.permission_ruleset = permission_ruleset - def matches(self, uri, mode=None): + def matches(self, uri: str, mode=None) -> bool: return self.cls.URI == uri and (mode is None or self.mode == mode) - def create(self, peer_certificate): + def create(self, peer_certificate) -> SecurityPolicy: if self.cls is SecurityPolicy: return self.cls(permissions=self.permission_ruleset) else: @@ -213,19 +225,19 @@ def create(self, peer_certificate): class Message: - def __init__(self, chunks): - self._chunks = chunks + def __init__(self, chunks: "List[MessageChunk]") -> None: + self._chunks: "List[MessageChunk]" = chunks - def request_id(self): + def request_id(self) -> auto.UInt32: return self._chunks[0].SequenceHeader.RequestId - def SequenceHeader(self): + def SequenceHeader(self) -> SequenceHeader: return self._chunks[0].SequenceHeader - def SecurityHeader(self): + def SecurityHeader(self) -> SymmetricAlgorithmHeader | AsymmetricAlgorithmHeader: return self._chunks[0].SecurityHeader - def body(self): + def body(self) -> utils.Buffer: body = b"".join([c.Body for c in self._chunks]) return utils.Buffer(body) From 27cc7824ae537784e0e66191adf8725d9a6a8c9f Mon Sep 17 00:00:00 2001 From: Jonathan Becker Date: Wed, 2 Oct 2024 13:18:19 +0200 Subject: [PATCH 2/5] revert change on reconciliator.py --- asyncua/client/ha/reconciliator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/asyncua/client/ha/reconciliator.py b/asyncua/client/ha/reconciliator.py index 1d267c1a5..e7055f9c2 100644 --- a/asyncua/client/ha/reconciliator.py +++ b/asyncua/client/ha/reconciliator.py @@ -150,8 +150,7 @@ async def reconciliate(self) -> None: for url in valid_urls: digest_ideal = get_digest(ideal_map[url]) digest_real = get_digest(real_map[url]) - #if url not in real_map or digest_ideal != digest_real: - if url not in real_map or ideal_map[url] != real_map[url]: + if url not in real_map or digest_ideal != digest_real: targets.add(url) if not targets: _logger.info( From 09403bec850a7ff50ff801dad38f474b945b0a04 Mon Sep 17 00:00:00 2001 From: Jonathan Becker Date: Wed, 2 Oct 2024 13:19:56 +0200 Subject: [PATCH 3/5] lint with ruff --- asyncua/common/connection.py | 4 ++-- asyncua/crypto/uacrypto.py | 2 -- asyncua/ua/uaprotocol_hand.py | 6 +++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/asyncua/common/connection.py b/asyncua/common/connection.py index 6ebb17c49..b074a8173 100644 --- a/asyncua/common/connection.py +++ b/asyncua/common/connection.py @@ -213,8 +213,8 @@ def __init__(self, security_policy: "SecurityPolicy", limits: TransportLimits) - self._sequence_number: int = 0 self._peer_sequence_number: Optional[int] = None self._incoming_parts: List[MessageChunk] = [] - self.security_policy: "SecurityPolicy" = security_policy - self._policies: "List[SecurityPolicyFactory]" = [] + self.security_policy: SecurityPolicy = security_policy + self._policies: List[SecurityPolicyFactory] = [] self._open: bool = False self.security_token = ua.ChannelSecurityToken() self.next_security_token = ua.ChannelSecurityToken() diff --git a/asyncua/crypto/uacrypto.py b/asyncua/crypto/uacrypto.py index ae9dcf1b0..e124e4c86 100644 --- a/asyncua/crypto/uacrypto.py +++ b/asyncua/crypto/uacrypto.py @@ -12,8 +12,6 @@ from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.ciphers import algorithms from cryptography.hazmat.primitives.ciphers import modes -from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PublicKey -from cryptography.hazmat.primitives.asymmetric.x448 import X448PublicKey # We redefine InvalidSignature as part of this module. Do not remove this line. from cryptography.exceptions import InvalidSignature # noqa: F811 diff --git a/asyncua/ua/uaprotocol_hand.py b/asyncua/ua/uaprotocol_hand.py index db68f4647..5cd5d4e76 100644 --- a/asyncua/ua/uaprotocol_hand.py +++ b/asyncua/ua/uaprotocol_hand.py @@ -210,8 +210,8 @@ def __init__(self, ) -> None: self.cls = cls self.mode: auto.MessageSecurityMode = mode - self.certificate: "Optional[x509.Certificate]" = certificate - self.private_key: "Optional[PrivateKeyTypes]" = private_key + self.certificate: Optional[x509.Certificate] = certificate + self.private_key: Optional[PrivateKeyTypes] = private_key self.permission_ruleset = permission_ruleset def matches(self, uri: str, mode=None) -> bool: @@ -226,7 +226,7 @@ def create(self, peer_certificate) -> SecurityPolicy: class Message: def __init__(self, chunks: "List[MessageChunk]") -> None: - self._chunks: "List[MessageChunk]" = chunks + self._chunks: List[MessageChunk] = chunks def request_id(self) -> auto.UInt32: return self._chunks[0].SequenceHeader.RequestId From a30c9c69b2241d15f5bd0a0f542ada141245d8b3 Mon Sep 17 00:00:00 2001 From: Jonathan Becker Date: Wed, 2 Oct 2024 13:23:45 +0200 Subject: [PATCH 4/5] avoid X | Y syntax for union --- asyncua/ua/uaprotocol_hand.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncua/ua/uaprotocol_hand.py b/asyncua/ua/uaprotocol_hand.py index 5cd5d4e76..cf8fb090a 100644 --- a/asyncua/ua/uaprotocol_hand.py +++ b/asyncua/ua/uaprotocol_hand.py @@ -1,6 +1,6 @@ import struct from dataclasses import dataclass, field -from typing import List, TYPE_CHECKING, Optional +from typing import List, TYPE_CHECKING, Optional, Union from asyncua.common.connection import MessageChunk from asyncua.ua import uaprotocol_auto as auto @@ -234,7 +234,7 @@ def request_id(self) -> auto.UInt32: def SequenceHeader(self) -> SequenceHeader: return self._chunks[0].SequenceHeader - def SecurityHeader(self) -> SymmetricAlgorithmHeader | AsymmetricAlgorithmHeader: + def SecurityHeader(self) -> Union[SymmetricAlgorithmHeader, AsymmetricAlgorithmHeader]: return self._chunks[0].SecurityHeader def body(self) -> utils.Buffer: From 73f37c76f74ee76ff229436f06f50332963cbd15 Mon Sep 17 00:00:00 2001 From: Jonathan Becker Date: Wed, 2 Oct 2024 13:33:07 +0200 Subject: [PATCH 5/5] fix circular import --- asyncua/common/connection.py | 8 ++++---- asyncua/ua/ua_binary.py | 7 ++----- asyncua/ua/uaprotocol_hand.py | 1 - 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/asyncua/common/connection.py b/asyncua/common/connection.py index b074a8173..3e433c2fe 100644 --- a/asyncua/common/connection.py +++ b/asyncua/common/connection.py @@ -56,7 +56,7 @@ def is_chunk_count_within_limit(self, sz: int) -> bool: _logger.error("Number of message chunks: %s is > configured max chunk count: %s", sz, self.max_chunk_count) return within_limit - def create_acknowledge_and_set_limits(self, msg: ua.Hello) -> ua.Acknowledge: + def create_acknowledge_and_set_limits(self, msg: "ua.Hello") -> "ua.Acknowledge": ack = ua.Acknowledge() ack.ReceiveBufferSize = min(msg.ReceiveBufferSize, self.max_send_buffer) ack.SendBufferSize = min(msg.SendBufferSize, self.max_recv_buffer) @@ -69,14 +69,14 @@ def create_acknowledge_and_set_limits(self, msg: ua.Hello) -> ua.Acknowledge: _logger.info("updating server limits to: %s", self) return ack - def create_hello_limits(self, msg: ua.Hello) -> ua.Hello: + def create_hello_limits(self, msg: "ua.Hello") -> "ua.Hello": msg.ReceiveBufferSize = self.max_recv_buffer msg.SendBufferSize = self.max_send_buffer msg.MaxChunkCount = self.max_chunk_count msg.MaxMessageSize = self.max_chunk_count return msg - def update_client_limits(self, msg: ua.Acknowledge) -> None: + def update_client_limits(self, msg: "ua.Acknowledge") -> None: self.max_chunk_count = msg.MaxChunkCount self.max_recv_buffer = msg.ReceiveBufferSize self.max_send_buffer = msg.SendBufferSize @@ -394,7 +394,7 @@ def _check_incoming_chunk(self, chunk): raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:" f" {self._peer_sequence_number}, received: {seq_num}," f" spec says to close connection") self._peer_sequence_number = seq_num - def receive_from_header_and_body(self, header: ua.Header, body: "Buffer") -> Union[None,ua.Message,ua.Hello,ua.Acknowledge,ua.ErrorMessage]: + def receive_from_header_and_body(self, header: ua.Header, body: "Buffer") -> Union[None,ua.Message,"ua.Hello",ua.Acknowledge,ua.ErrorMessage]: """ Convert MessageHeader and binary body to OPC UA TCP message (see OPC UA specs Part 6, 7.1: Hello, Acknowledge or ErrorMessage), or a Message diff --git a/asyncua/ua/ua_binary.py b/asyncua/ua/ua_binary.py index ed19147fa..1d3b7de8b 100644 --- a/asyncua/ua/ua_binary.py +++ b/asyncua/ua/ua_binary.py @@ -6,7 +6,7 @@ import struct import logging from io import BytesIO -from typing import IO, Any, Callable, Optional, Sequence, Type, TypeVar, Union, TYPE_CHECKING +from typing import IO, Any, Callable, Optional, Sequence, Type, TypeVar, Union import typing import uuid from enum import Enum, IntFlag @@ -16,9 +16,6 @@ from ..common.utils import Buffer from .uatypes import type_from_optional, type_is_list, type_is_union, type_from_list, types_or_list_from_union, type_allow_subclass -if TYPE_CHECKING: - from asyncua.common.utils import Buffer - _logger = logging.getLogger(__name__) T = TypeVar('T') @@ -711,7 +708,7 @@ def header_to_binary(hdr) -> bytes: return b"".join(b) -def header_from_binary(data) -> ua.Header: +def header_from_binary(data) -> "ua.Header": hdr = ua.Header() hdr.MessageType, hdr.ChunkType, hdr.packet_size = struct.unpack("<3scI", data.read(8)) hdr.body_size = hdr.packet_size - 8 diff --git a/asyncua/ua/uaprotocol_hand.py b/asyncua/ua/uaprotocol_hand.py index cf8fb090a..6d6065830 100644 --- a/asyncua/ua/uaprotocol_hand.py +++ b/asyncua/ua/uaprotocol_hand.py @@ -2,7 +2,6 @@ from dataclasses import dataclass, field from typing import List, TYPE_CHECKING, Optional, Union -from asyncua.common.connection import MessageChunk from asyncua.ua import uaprotocol_auto as auto from asyncua.ua import uatypes from asyncua.common import utils