From b8fedf71a8ccb145b80e1d3f87c9b77b196ff20e Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 27 May 2020 12:10:43 +0200 Subject: [PATCH 01/31] Check if it is allow to start a pending --- ospd/command/command.py | 44 ---------------------------------------- ospd/ospd.py | 45 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 44 deletions(-) diff --git a/ospd/command/command.py b/ospd/command/command.py index 295df59f..7903d804 100644 --- a/ospd/command/command.py +++ b/ospd/command/command.py @@ -471,38 +471,6 @@ def get_elements(self): return elements - def is_new_scan_allowed(self) -> bool: - """ Check if max_scans has been reached. - - Return: - True if a new scan can be launch. - """ - if (self._daemon.max_scans == 0) or ( - len(self._daemon.scan_processes) < self._daemon.max_scans - ): - return True - - return False - - def is_enough_free_memory(self) -> bool: - """ Check if there is enough free memory in the system to run - a new scan. The necessary memory is a rough calculation and very - conservative. - - Return: - True if there is enough memory for a new scan. - """ - - ps_process = psutil.Process() - proc_memory = ps_process.memory_info().rss - - free_mem = psutil.virtual_memory().free - - if free_mem > (4 * proc_memory): - return True - - return False - def handle_xml(self, xml: Element) -> bytes: """ Handles command. @@ -510,18 +478,6 @@ def handle_xml(self, xml: Element) -> bytes: Response string for command. """ - if self._daemon.check_free_memory and not self.is_enough_free_memory(): - raise OspdCommandError( - 'Not possible to run a new scan. Not enough free memory.', - 'start_scan', - ) - - if not self.is_new_scan_allowed(): - raise OspdCommandError( - 'Not possible to run a new scan. Max scan limit reached.', - 'start_scan', - ) - target_str = xml.get('target') ports_str = xml.get('ports') diff --git a/ospd/ospd.py b/ospd/ospd.py index 88d623cb..e1dfd69a 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1182,6 +1182,19 @@ def run(self) -> None: logger.info("Received Ctrl-C shutting-down ...") def start_pending_scans(self): + + if self._daemon.check_free_memory and not self.is_enough_free_memory(): + raise OspdCommandError( + 'Not possible to run a new scan. Not enough free memory.', + 'start_scan', + ) + + if not self.is_new_scan_allowed(): + raise OspdCommandError( + 'Not possible to run a new scan. Max scan limit reached.', + 'start_scan', + ) + for scan_id in self.scan_collection.ids_iterator(): if self.get_scan_status(scan_id) == ScanStatus.PENDING: scan_func = self.start_scan @@ -1190,6 +1203,38 @@ def start_pending_scans(self): scan_process.start() self.set_scan_status(scan_id, ScanStatus.INIT) + def is_new_scan_allowed(self) -> bool: + """ Check if max_scans has been reached. + + Return: + True if a new scan can be launch. + """ + if (self.max_scans == 0) or ( + len(self.scan_processes) < self.max_scans + ): + return True + + return False + + def is_enough_free_memory(self) -> bool: + """ Check if there is enough free memory in the system to run + a new scan. The necessary memory is a rough calculation and very + conservative. + + Return: + True if there is enough memory for a new scan. + """ + + ps_process = psutil.Process() + proc_memory = ps_process.memory_info().rss + + free_mem = psutil.virtual_memory().free + + if free_mem > (4 * proc_memory): + return True + + return False + def scheduler(self): """ Should be implemented by subclass in case of need to run tasks periodically. 
""" From ccff9b9a3e89d820159da82c34b99a0b4fb9b81a Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 27 May 2020 12:11:22 +0200 Subject: [PATCH 02/31] Check if it is allow to start a pending. Move the check from star_scan command handler to start_pending_scans(). This is because the new scans are not rejected anymore, but will not be started if the max_scans value was reached or if there is no enough free memory. --- ospd/ospd.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index e1dfd69a..ae21dca8 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1183,7 +1183,7 @@ def run(self) -> None: def start_pending_scans(self): - if self._daemon.check_free_memory and not self.is_enough_free_memory(): + if self._daemon.check_free_memory and not self.is_enough_free_memory(): raise OspdCommandError( 'Not possible to run a new scan. Not enough free memory.', 'start_scan', @@ -1209,9 +1209,7 @@ def is_new_scan_allowed(self) -> bool: Return: True if a new scan can be launch. """ - if (self.max_scans == 0) or ( - len(self.scan_processes) < self.max_scans - ): + if (self.max_scans == 0) or (len(self.scan_processes) < self.max_scans): return True return False From c6c47943bf15142c8031041b0d48f6238203ea46 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 27 May 2020 12:19:09 +0200 Subject: [PATCH 03/31] Check for each pending scan if it is allowed. Also, replace the error raise for a debug log message. --- ospd/ospd.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index ae21dca8..00f2a582 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1182,20 +1182,21 @@ def run(self) -> None: logger.info("Received Ctrl-C shutting-down ...") def start_pending_scans(self): + """ Starts a pending scan if it is allowed """ - if self._daemon.check_free_memory and not self.is_enough_free_memory(): - raise OspdCommandError( - 'Not possible to run a new scan. Not enough free memory.', - 'start_scan', - ) + for scan_id in self.scan_collection.ids_iterator(): + if not self.is_new_scan_allowed(): + logger.debug( + 'Not possible to run a new scan. Max scan limit reached.' + ) + return - if not self.is_new_scan_allowed(): - raise OspdCommandError( - 'Not possible to run a new scan. Max scan limit reached.', - 'start_scan', - ) + if self.check_free_memory and not self.is_enough_free_memory(): + logger.debug( + 'Not possible to run a new scan. Not enough free memory.' + ) + return - for scan_id in self.scan_collection.ids_iterator(): if self.get_scan_status(scan_id) == ScanStatus.PENDING: scan_func = self.start_scan scan_process = create_process(func=scan_func, args=(scan_id,)) From 07402c9ac10a84a9041118c151c0bcb9de999ef1 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 27 May 2020 14:45:35 +0200 Subject: [PATCH 04/31] Add new argument file_storage_dir Receive the directory path where the picked scan info will be stored. Initialize the ScanCollection() with this path. --- ospd/ospd.py | 3 ++- ospd/scan.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 00f2a582..8c1aff16 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -117,10 +117,11 @@ def __init__( storage=None, max_scans=0, check_free_memory=False, + file_storage_dir='/var/run/ospd', **kwargs ): # pylint: disable=unused-argument """ Initializes the daemon's internal data. 
""" - self.scan_collection = ScanCollection() + self.scan_collection = ScanCollection(file_storage_dir) self.scan_processes = dict() self.daemon_info = dict() diff --git a/ospd/scan.py b/ospd/scan.py index 3d790651..e627baf9 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -56,13 +56,14 @@ class ScanCollection: """ - def __init__(self) -> None: + def __init__(self, file_storage_dir) -> None: """ Initialize the Scan Collection. """ self.data_manager = ( None ) # type: Optional[multiprocessing.managers.SyncManager] self.scans_table = dict() # type: Dict + self.file_storage_dir = file_storage_dir def init(self): self.data_manager = multiprocessing.Manager() From de53fb824ad842e87633fc2f1e26973766c14c6d Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 27 May 2020 16:36:11 +0200 Subject: [PATCH 05/31] Create the the dictionary for the new scan in the scan table. Only the minimal information is stored. Other information is prepared in a dictionary to be pickled later. The credentials are taken from the target dictionary and stored in the scan_table. This prevent to pickle sensitive information. Storing credentials does not have big impact in the memory usage. --- ospd/scan.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/ospd/scan.py b/ospd/scan.py index e627baf9..79475486 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -211,24 +211,21 @@ def create_scan( ) -> str: """ Creates a new scan with provided scan information. """ - if not target: - target = {} - if not options: options = dict() + credentials = target.pop('credentials') + scan_info = self.data_manager.dict() # type: Dict - scan_info['results'] = list() - scan_info['progress'] = 0 - scan_info['target_progress'] = dict() - scan_info['count_alive'] = 0 - scan_info['count_dead'] = 0 - scan_info['target'] = target - scan_info['vts'] = vts - scan_info['options'] = options - scan_info['start_time'] = int(time.time()) - scan_info['end_time'] = 0 scan_info['status'] = ScanStatus.PENDING + scan_info['credentials'] = credentials + scan_info['start_time'] = int(time.time()) + + scan_info_to_pikle = { + 'target': target, + 'options': options, + 'vts': vts, + } if scan_id is None or scan_id == '': scan_id = str(uuid.uuid4()) @@ -370,7 +367,7 @@ def get_credentials(self, scan_id: str) -> Dict[str, Dict[str, str]]: """ Get a scan's credential list. It return dictionary with the corresponding credential for a given target. """ - return self.scans_table[scan_id]['target'].get('credentials') + return self.scans_table[scan_id].get('credentials') def get_target_options(self, scan_id: str) -> Dict[str, str]: """ Get a scan's target option dictionary. From 00d3fae6ffed5be4dccebf3c87ba860812eebf90 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 27 May 2020 16:41:04 +0200 Subject: [PATCH 06/31] Add method to pickle the scan info into a file. 
Call the new method from create_scan() --- ospd/scan.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ospd/scan.py b/ospd/scan.py index 79475486..3a427c54 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -19,7 +19,9 @@ import multiprocessing import time import uuid +import pickle +from pathlib import Path from collections import OrderedDict from enum import Enum from typing import List, Any, Dict, Iterator, Optional, Iterable, Union @@ -202,6 +204,13 @@ def ids_iterator(self) -> Iterator[str]: return iter(self.scans_table.keys()) + def pickle_scan_info(self, scan_id, scan_info): + """ Pickle a scan_info object and stored it in a file named as the scan_id""" + + storage_file_path = Path(self.file_storage_dir) / scan_id + with storage_file_path.open('wb') as scan_info_f: + pickle.dump(scan_info, scan_info_f) + def create_scan( self, scan_id: str = '', @@ -230,6 +239,8 @@ def create_scan( if scan_id is None or scan_id == '': scan_id = str(uuid.uuid4()) + self.pickle_scan_info(scan_id, scan_info_to_pikle) + scan_info['scan_id'] = scan_id self.scans_table[scan_id] = scan_info From f2a30ef40246f9b7f600e2be0dfe35fffbdc531e Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 27 May 2020 16:42:27 +0200 Subject: [PATCH 07/31] Add method to unpickle the scan_info. The info is stored in the corresponding slot in the scan_table Call the method to unpickle from start_scan(). The file will be deleted. --- ospd/ospd.py | 1 + ospd/scan.py | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/ospd/ospd.py b/ospd/ospd.py index 8c1aff16..14e35250 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1199,6 +1199,7 @@ def start_pending_scans(self): return if self.get_scan_status(scan_id) == ScanStatus.PENDING: + self.scan_collection.unpikle_scan_info(scan_id) scan_func = self.start_scan scan_process = create_process(func=scan_func, args=(scan_id,)) self.scan_processes[scan_id] = scan_process diff --git a/ospd/scan.py b/ospd/scan.py index 3a427c54..4c970fb0 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -211,6 +211,32 @@ def pickle_scan_info(self, scan_id, scan_info): with storage_file_path.open('wb') as scan_info_f: pickle.dump(scan_info, scan_info_f) + def unpikle_scan_info(self, scan_id): + """ Unpikle the scan_info correspinding to the scan_id and store it in the + scan_table """ + + storage_file_path = Path(self.file_storage_dir) / scan_id + unpikled_scan_info = None + with storage_file_path.open('rb') as scan_info_f: + unpikled_scan_info = pickle.load(scan_info_f) + + scan_info = self.scans_table.get(scan_id) + + scan_info['results'] = list() + scan_info['progress'] = 0 + scan_info['target_progress'] = dict() + scan_info['count_alive'] = 0 + scan_info['count_dead'] = 0 + scan_info['target'] = unpikled_scan_info.pop('target') + scan_info['vts'] = unpikled_scan_info.pop('vts') + scan_info['options'] = unpikled_scan_info.pop('options') + scan_info['start_time'] = int(time.time()) + scan_info['end_time'] = 0 + + self.scans_table[scan_id] = scan_info + + storage_file_path.unlink() + def create_scan( self, scan_id: str = '', From 6bc4a8f1a9b82aee924219e429f694b5e95527ac Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 27 May 2020 17:04:32 +0200 Subject: [PATCH 08/31] Support stopping a PENDING scan. Removes the file with the pickled scan_info and set the scan status to STOPPED. 
This allows the client to delete the scan --- ospd/ospd.py | 6 ++++++ ospd/scan.py | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/ospd/ospd.py b/ospd/ospd.py index 14e35250..961a399e 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -367,6 +367,12 @@ def process_targets_element(scanner_target) -> List: return OspRequest.process_target_element(scanner_target) def stop_scan(self, scan_id: str) -> None: + + if self.get_scan_status(scan_id) == ScanStatus.PENDING: + self.scan_collection.remove_file_pickled_scan_info(scan_id) + self.set_scan_status(scan_id, ScanStatus.STOPPED) + return + scan_process = self.scan_processes.get(scan_id) if not scan_process: raise OspdCommandError( diff --git a/ospd/scan.py b/ospd/scan.py index 4c970fb0..67b6af4a 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -204,6 +204,11 @@ def ids_iterator(self) -> Iterator[str]: return iter(self.scans_table.keys()) + def remove_file_pickled_scan_info(self, scan_id): + """ Remove the file containing a scan_info pickled object """ + storage_file_path = Path(self.file_storage_dir) / scan_id + storage_file_path.unlink(missing_ok=True) + def pickle_scan_info(self, scan_id, scan_info): """ Pickle a scan_info object and stored it in a file named as the scan_id""" From 9c6c3cf2f4ab3517f928aa748a402b6212f827a5 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 27 May 2020 18:40:38 +0200 Subject: [PATCH 09/31] Support deleting a stopped scan which previous status was PENDING --- ospd/ospd.py | 39 +++++++++++++++++++++++++++------------ ospd/scan.py | 8 ++++---- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 961a399e..596ef148 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -688,7 +688,7 @@ def delete_scan(self, scan_id: str) -> int: self.scan_processes[scan_id].join() exitcode = self.scan_processes[scan_id].exitcode except KeyError: - logger.debug('Scan process for %s not found', scan_id) + logger.debug('Scan process for %s never started,', scan_id) if exitcode or exitcode == 0: del self.scan_processes[scan_id] @@ -765,16 +765,27 @@ def get_scan_xml( if not scan_id: return Element('scan') - target = self.get_scan_host(scan_id) - progress = self.get_scan_progress(scan_id) - status = self.get_scan_status(scan_id) - start_time = self.get_scan_start_time(scan_id) - end_time = self.get_scan_end_time(scan_id) - response = Element('scan') + if self.get_scan_status(scan_id) == ScanStatus.PENDING: + target = '' + scan_progress = 0 + status = self.get_scan_status(scan_id) + start_time = 0 + end_time = 0 + response = Element('scan') + detailed = False + progress = False + else: + target = self.get_scan_host(scan_id) + scan_progress = self.get_scan_progress(scan_id) + status = self.get_scan_status(scan_id) + start_time = self.get_scan_start_time(scan_id) + end_time = self.get_scan_end_time(scan_id) + response = Element('scan') + for name, value in [ ('id', scan_id), ('target', target), - ('progress', progress), + ('progress', scan_progress), ('status', status.name.lower()), ('start_time', start_time), ('end_time', end_time), @@ -1314,13 +1325,17 @@ def clean_forgotten_scans(self) -> None: def check_scan_process(self, scan_id: str) -> None: """ Check the scan's process, and terminate the scan if not alive. 
""" - scan_process = self.scan_processes.get(scan_id) - progress = self.get_scan_progress(scan_id) - if self.get_scan_status(scan_id) == ScanStatus.PENDING: return - if progress < PROGRESS_FINISHED and not scan_process.is_alive(): + scan_process = self.scan_processes.get(scan_id) + progress = self.get_scan_progress(scan_id) + + if ( + progress < PROGRESS_FINISHED + and scan_process + and not scan_process.is_alive() + ): if not self.get_scan_status(scan_id) == ScanStatus.STOPPED: self.set_scan_status(scan_id, ScanStatus.STOPPED) self.add_scan_error( diff --git a/ospd/scan.py b/ospd/scan.py index 67b6af4a..b078d45c 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -207,7 +207,7 @@ def ids_iterator(self) -> Iterator[str]: def remove_file_pickled_scan_info(self, scan_id): """ Remove the file containing a scan_info pickled object """ storage_file_path = Path(self.file_storage_dir) / scan_id - storage_file_path.unlink(missing_ok=True) + storage_file_path.unlink() def pickle_scan_info(self, scan_id, scan_info): """ Pickle a scan_info object and stored it in a file named as the scan_id""" @@ -286,12 +286,12 @@ def set_status(self, scan_id: str, status: ScanStatus) -> None: def get_status(self, scan_id: str) -> ScanStatus: """ Get scan_id scans's status.""" - return self.scans_table[scan_id]['status'] + return self.scans_table[scan_id].get('status') def get_options(self, scan_id: str) -> Dict: """ Get scan_id scan's options list. """ - return self.scans_table[scan_id]['options'] + return self.scans_table[scan_id].get('options') def set_option(self, scan_id, name: str, value: Any) -> None: """ Set a scan_id scan's name option to value. """ @@ -301,7 +301,7 @@ def set_option(self, scan_id, name: str, value: Any) -> None: def get_progress(self, scan_id: str) -> int: """ Get a scan's current progress value. """ - return self.scans_table[scan_id]['progress'] + return self.scans_table[scan_id].get('progress', 0) def get_count_dead(self, scan_id: str) -> int: """ Get a scan's current dead host count. 
""" From fe63af136514e0b500f6e9524c752c7fbe75fcfe Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 27 May 2020 19:04:07 +0200 Subject: [PATCH 10/31] Check for existence of the scan before trying to stop a PENDING scan --- ospd/ospd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 596ef148..6630badf 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -367,8 +367,10 @@ def process_targets_element(scanner_target) -> List: return OspRequest.process_target_element(scanner_target) def stop_scan(self, scan_id: str) -> None: - - if self.get_scan_status(scan_id) == ScanStatus.PENDING: + if ( + scan_id in self.scan_collection.ids_iterator() + and self.get_scan_status(scan_id) == ScanStatus.PENDING + ): self.scan_collection.remove_file_pickled_scan_info(scan_id) self.set_scan_status(scan_id, ScanStatus.STOPPED) return From 4646c0c673e3d7e009574b68020e56ab8690b296 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Thu, 28 May 2020 12:43:05 +0200 Subject: [PATCH 11/31] Add class to handle pickled data in files --- ospd/datapickler.py | 73 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 ospd/datapickler.py diff --git a/ospd/datapickler.py b/ospd/datapickler.py new file mode 100644 index 00000000..dd2430f4 --- /dev/null +++ b/ospd/datapickler.py @@ -0,0 +1,73 @@ +# Copyright (C) 2014-2020 Greenbone Networks GmbH +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +""" Pikle Handler class +""" + +import logging +import pickle + +from pathlib import Path +from ospd.errors import OspdCommandError + +logger = logging.getLogger(__name__) + + +class DataPickler: + def __init__(self, storage_path): + self._storage_path = storage_path + + def remove_file(self, filename): + """ Remove the file containing a scan_info pickled object """ + storage_file_path = Path(self._storage_path) / filename + storage_file_path.unlink() + + def store_data(self, filename, data_object): + """ Pickle a object and store it in a file named""" + storage_file_path = Path(self._storage_path) / filename + + try: + # create parent directories recursively + parent_dir = storage_file_path.parent + parent_dir.mkdir(parents=True, exist_ok=True) + except Exception as e: # pylint: disable=broad-except + raise OspdCommandError( + 'Not possible to store scan info for %s. %s' % (filename, e), + 'start_scan', + ) + + try: + with storage_file_path.open('wb') as scan_info_f: + pickle.dump(data_object, scan_info_f) + except Exception as e: + raise OspdCommandError( + 'Not possible to store scan info for %s. 
%s' % (filename, e), + 'start_scan', + ) + + def load_data(self, filename): + """ Unpikle stored data """ + + storage_file_path = Path(self._storage_path) / filename + unpikled_scan_info = None + try: + with storage_file_path.open('rb') as scan_info_f: + unpikled_scan_info = pickle.load(scan_info_f) + except Exception as e: + logger.error('Not possible to load data from %s. %s', filename, e) + + return unpikled_scan_info From 1f741f49e7d115d8c410285df185b6edc192d77c Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Thu, 28 May 2020 12:43:39 +0200 Subject: [PATCH 12/31] Use DataPickler class and improve error handling. --- ospd/ospd.py | 8 +++++++- ospd/scan.py | 52 +++++++++++++++++++++++++++------------------------- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 6630badf..86952671 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1218,7 +1218,13 @@ def start_pending_scans(self): return if self.get_scan_status(scan_id) == ScanStatus.PENDING: - self.scan_collection.unpikle_scan_info(scan_id) + try: + self.scan_collection.unpickle_scan_info(scan_id) + except OspdCommandError as e: + logger.error("Start scan error %s", e) + self.stop_scan(scan_id) + continue + scan_func = self.start_scan scan_process = create_process(func=scan_func, args=(scan_id,)) self.scan_processes[scan_id] = scan_process diff --git a/ospd/scan.py b/ospd/scan.py index b078d45c..41a9da89 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -27,6 +27,8 @@ from typing import List, Any, Dict, Iterator, Optional, Iterable, Union from ospd.network import target_str_to_list +from ospd.datapickler import DataPickler +from ospd.errors import OspdCommandError LOGGER = logging.getLogger(__name__) @@ -205,25 +207,20 @@ def ids_iterator(self) -> Iterator[str]: return iter(self.scans_table.keys()) def remove_file_pickled_scan_info(self, scan_id): - """ Remove the file containing a scan_info pickled object """ - storage_file_path = Path(self.file_storage_dir) / scan_id - storage_file_path.unlink() - - def pickle_scan_info(self, scan_id, scan_info): - """ Pickle a scan_info object and stored it in a file named as the scan_id""" - - storage_file_path = Path(self.file_storage_dir) / scan_id - with storage_file_path.open('wb') as scan_info_f: - pickle.dump(scan_info, scan_info_f) - - def unpikle_scan_info(self, scan_id): - """ Unpikle the scan_info correspinding to the scan_id and store it in the - scan_table """ - - storage_file_path = Path(self.file_storage_dir) / scan_id - unpikled_scan_info = None - with storage_file_path.open('rb') as scan_info_f: - unpikled_scan_info = pickle.load(scan_info_f) + pickler = DataPickler(self.file_storage_dir) + pickler.remove_file(scan_id) + + def unpickle_scan_info(self, scan_id): + """ Unpickle a stored scan_inf correspinding to the scan_id + and store it in the scan_table """ + pickler = DataPickler(self.file_storage_dir) + unpickled_scan_info = pickler.load_data(scan_id) + + if not unpickled_scan_info: + raise OspdCommandError( + 'Not possible to unpickle stored scan info for %s' % scan_id, + 'start_scan', + ) scan_info = self.scans_table.get(scan_id) @@ -232,15 +229,15 @@ def unpikle_scan_info(self, scan_id): scan_info['target_progress'] = dict() scan_info['count_alive'] = 0 scan_info['count_dead'] = 0 - scan_info['target'] = unpikled_scan_info.pop('target') - scan_info['vts'] = unpikled_scan_info.pop('vts') - scan_info['options'] = unpikled_scan_info.pop('options') + scan_info['target'] = unpickled_scan_info.pop('target') + scan_info['vts'] = 
unpickled_scan_info.pop('vts') + scan_info['options'] = unpickled_scan_info.pop('options') scan_info['start_time'] = int(time.time()) scan_info['end_time'] = 0 self.scans_table[scan_id] = scan_info - storage_file_path.unlink() + pickler.remove_file(scan_id) def create_scan( self, @@ -261,7 +258,7 @@ def create_scan( scan_info['credentials'] = credentials scan_info['start_time'] = int(time.time()) - scan_info_to_pikle = { + scan_info_to_pickle = { 'target': target, 'options': options, 'vts': vts, @@ -270,7 +267,12 @@ def create_scan( if scan_id is None or scan_id == '': scan_id = str(uuid.uuid4()) - self.pickle_scan_info(scan_id, scan_info_to_pikle) + pickler = DataPickler(self.file_storage_dir) + try: + pickler.store_data(scan_id, scan_info_to_pickle) + except OspdCommandError as e: + logger.error(e) + return scan_info['scan_id'] = scan_id From 96bbc7bc0e4e4036f510b479f6b5f1888ba6ae26 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Thu, 28 May 2020 14:40:02 +0200 Subject: [PATCH 13/31] Add integrity check for pickled data --- ospd/datapickler.py | 64 +++++++++++++++++++++++++++++++++++++-------- ospd/scan.py | 14 ++++++---- 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/ospd/datapickler.py b/ospd/datapickler.py index dd2430f4..733eda76 100644 --- a/ospd/datapickler.py +++ b/ospd/datapickler.py @@ -15,13 +15,16 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -""" Pikle Handler class +""" Pickle Handler class """ import logging import pickle +from hashlib import sha256 from pathlib import Path +from typing import Dict + from ospd.errors import OspdCommandError logger = logging.getLogger(__name__) @@ -36,7 +39,7 @@ def remove_file(self, filename): storage_file_path = Path(self._storage_path) / filename storage_file_path.unlink() - def store_data(self, filename, data_object): + def store_data(self, filename: str, data_object: Dict) -> str: """ Pickle a object and store it in a file named""" storage_file_path = Path(self._storage_path) / filename @@ -46,28 +49,67 @@ def store_data(self, filename, data_object): parent_dir.mkdir(parents=True, exist_ok=True) except Exception as e: # pylint: disable=broad-except raise OspdCommandError( - 'Not possible to store scan info for %s. %s' % (filename, e), + 'Not possible to access dir for %s. %s' % (filename, e), + 'start_scan', + ) + + try: + pickled_data = pickle.dumps(data_object) + except pickle.PicklingError as e: + raise OspdCommandError( + 'Not possible to pickle scan info for %s. %s' % (filename, e), 'start_scan', ) try: with storage_file_path.open('wb') as scan_info_f: - pickle.dump(data_object, scan_info_f) - except Exception as e: + scan_info_f.write(pickled_data) + except Exception as e: # pylint: disable=broad-except raise OspdCommandError( 'Not possible to store scan info for %s. %s' % (filename, e), 'start_scan', ) - def load_data(self, filename): - """ Unpikle stored data """ + return self._pickled_data_hash_generator(pickled_data) + + def load_data(self, filename: str, original_data_hash: str) -> Dict: + """ Unpickle the stored data in the filename. Perform an + intengrity check of the read data with the the hash generated + with the original data. + + Return: + Dictionary containing the scan info. None otherwise. 
+ """ storage_file_path = Path(self._storage_path) / filename - unpikled_scan_info = None + pickled_data = None try: with storage_file_path.open('rb') as scan_info_f: - unpikled_scan_info = pickle.load(scan_info_f) + pickled_data = scan_info_f.read() except Exception as e: - logger.error('Not possible to load data from %s. %s', filename, e) + logger.error( + 'Not possible to read pickled data from %s. %s', filename, e + ) + + unpickled_scan_info = None + try: + unpickled_scan_info = pickle.loads(pickled_data) + except pickle.UnpicklingError as e: + logger.error( + 'Not possible to read pickled data from %s. %s', filename, e + ) + + pickled_scan_info_hash = self._pickled_data_hash_generator(pickled_data) + + if original_data_hash == pickled_scan_info_hash: + return unpickled_scan_info + + def _pickled_data_hash_generator(self, pickled_data): + """ Calculate the sha256 hash of a pickled data """ + if not pickled_data: + return + + hash_sha256 = sha256() + hash_sha256.update(pickle.dumps(pickled_data)) - return unpikled_scan_info + return hash_sha256.hexdigest() diff --git a/ospd/scan.py b/ospd/scan.py index 41a9da89..a2654074 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -213,8 +213,12 @@ def remove_file_pickled_scan_info(self, scan_id): def unpickle_scan_info(self, scan_id): """ Unpickle a stored scan_inf correspinding to the scan_id and store it in the scan_table """ + + scan_info = self.scans_table.get(scan_id) + scan_info_hash = scan_info.pop('scan_info_hash') + pickler = DataPickler(self.file_storage_dir) - unpickled_scan_info = pickler.load_data(scan_id) + unpickled_scan_info = pickler.load_data(scan_id, scan_info_hash) if not unpickled_scan_info: raise OspdCommandError( @@ -222,8 +226,6 @@ def unpickle_scan_info(self, scan_id): 'start_scan', ) - scan_info = self.scans_table.get(scan_id) - scan_info['results'] = list() scan_info['progress'] = 0 scan_info['target_progress'] = dict() @@ -268,13 +270,15 @@ def create_scan( scan_id = str(uuid.uuid4()) pickler = DataPickler(self.file_storage_dir) + scan_info_hash = None try: - pickler.store_data(scan_id, scan_info_to_pickle) + scan_info_hash = pickler.store_data(scan_id, scan_info_to_pickle) except OspdCommandError as e: - logger.error(e) + LOGGER.error(e) return scan_info['scan_id'] = scan_id + scan_info['scan_info_hash'] = scan_info_hash self.scans_table[scan_id] = scan_info return scan_id From e8730e2e39d30a4a406a86fccb5979e3f505315e Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Thu, 28 May 2020 15:53:24 +0200 Subject: [PATCH 14/31] Fix test --- tests/command/test_commands.py | 39 +++++++++------------------------- tests/helper.py | 1 + tests/test_scan_and_result.py | 38 +++++++++++++++++++++++++++++++-- tests/test_ssh_daemon.py | 7 ++++-- 4 files changed, 52 insertions(+), 33 deletions(-) diff --git a/tests/command/test_commands.py b/tests/command/test_commands.py index 2c658f46..1c9b1e11 100644 --- a/tests/command/test_commands.py +++ b/tests/command/test_commands.py @@ -119,6 +119,7 @@ def test_scan_with_vts(self, mock_create_process): # With one vt, without params response = et.fromstring(cmd.handle_xml(request)) + daemon.start_pending_scans() scan_id = response.findtext('id') vts_collection = daemon.get_scan_vts(scan_id) @@ -150,7 +151,7 @@ def test_scan_pop_vts(self): # With one vt, without params response = et.fromstring(cmd.handle_xml(request)) scan_id = response.findtext('id') - + daemon.start_pending_scans() vts_collection = daemon.get_scan_vts(scan_id) self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': 
[]}) self.assertRaises(KeyError, daemon.get_scan_vts, scan_id) @@ -176,36 +177,13 @@ def test_scan_pop_ports(self): # With one vt, without params response = et.fromstring(cmd.handle_xml(request)) + daemon.start_pending_scans() scan_id = response.findtext('id') ports = daemon.scan_collection.get_ports(scan_id) self.assertEqual(ports, '80, 443') self.assertRaises(KeyError, daemon.scan_collection.get_ports, scan_id) - def test_is_new_scan_allowed_false(self): - daemon = DummyWrapper([]) - cmd = StartScan(daemon) - - cmd._daemon.scan_processes = { # pylint: disable=protected-access - 'a': 1, - 'b': 2, - } - daemon.max_scans = 1 - - self.assertFalse(cmd.is_new_scan_allowed()) - - def test_is_new_scan_allowed_true(self): - daemon = DummyWrapper([]) - cmd = StartScan(daemon) - - cmd._daemon.scan_processes = { # pylint: disable=protected-access - 'a': 1, - 'b': 2, - } - daemon.max_scans = 3 - - self.assertTrue(cmd.is_new_scan_allowed()) - @patch("ospd.ospd.create_process") def test_scan_without_vts(self, mock_create_process): daemon = DummyWrapper([]) @@ -223,12 +201,13 @@ def test_scan_without_vts(self, mock_create_process): '' '' ) + response = et.fromstring(cmd.handle_xml(request)) + daemon.start_pending_scans() scan_id = response.findtext('id') self.assertEqual(daemon.get_scan_vts(scan_id), {}) - daemon.start_pending_scans() assert_called(mock_create_process) def test_scan_with_vts_and_param_missing_vt_param_id(self): @@ -277,6 +256,8 @@ def test_scan_with_vts_and_param(self, mock_create_process): '' ) response = et.fromstring(cmd.handle_xml(request)) + daemon.start_pending_scans() + scan_id = response.findtext('id') self.assertEqual( @@ -303,6 +284,7 @@ def test_scan_with_vts_and_param_missing_vt_group_filter(self): '' '' ) + daemon.start_pending_scans() with self.assertRaises(OspdError): cmd.handle_xml(request) @@ -330,11 +312,11 @@ def test_scan_with_vts_and_param_with_vt_group_filter( '' ) response = et.fromstring(cmd.handle_xml(request)) + daemon.start_pending_scans() scan_id = response.findtext('id') self.assertEqual(daemon.get_scan_vts(scan_id), {'vt_groups': ['a']}) - daemon.start_pending_scans() assert_called(mock_create_process) @patch("ospd.ospd.create_process") @@ -375,6 +357,7 @@ def test_scan_use_legacy_target_and_port( ) response = et.fromstring(cmd.handle_xml(request)) + daemon.start_pending_scans() scan_id = response.findtext('id') self.assertIsNotNone(scan_id) @@ -382,8 +365,6 @@ def test_scan_use_legacy_target_and_port( self.assertEqual(daemon.get_scan_host(scan_id), 'localhost') self.assertEqual(daemon.get_scan_ports(scan_id), '22') - daemon.start_pending_scans() - assert_called(mock_logger.warning) assert_called(mock_create_process) diff --git a/tests/helper.py b/tests/helper.py index 8aec21a0..e718b7a7 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -64,6 +64,7 @@ def __init__(self, results, checkresult=True): self.results = results self.initialized = True self.scan_collection.data_manager = FakeDataManager() + self.scan_collection.file_storage_dir = '/tmp' def check(self): return self.checkresult diff --git a/tests/test_scan_and_result.py b/tests/test_scan_and_result.py index 12e4f61b..16464bde 100644 --- a/tests/test_scan_and_result.py +++ b/tests/test_scan_and_result.py @@ -79,6 +79,7 @@ class ScanTestCase(unittest.TestCase): def setUp(self): self.daemon = DummyWrapper([]) self.daemon.scan_collection.datamanager = FakeDataManager() + self.daemon.scan_collection.file_storage_dir = '/tmp' def test_get_default_scanner_params(self): fs = FakeStream() @@ -612,6 +613,7 
@@ def test_get_scan_pop(self): '', fs, ) + self.daemon.start_pending_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -654,6 +656,7 @@ def test_get_scan_pop_max_res(self): '', fs, ) + self.daemon.start_pending_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -723,6 +726,7 @@ def test_target_with_credentials(self): '', fs, ) + self.daemon.start_pending_scans() response = fs.get_response() self.assertEqual(response.get('status'), '200') @@ -753,6 +757,8 @@ def test_scan_get_target(self): '', fs, ) + self.daemon.start_pending_scans() + response = fs.get_response() scan_id = response.findtext('id') @@ -776,6 +782,8 @@ def test_scan_get_target_options(self): '', fs, ) + self.daemon.start_pending_scans() + response = fs.get_response() scan_id = response.findtext('id') @@ -796,6 +804,7 @@ def test_progress(self): '', fs, ) + self.daemon.start_pending_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -819,6 +828,8 @@ def test_sort_host_finished(self): '', fs, ) + self.daemon.start_pending_scans() + response = fs.get_response() scan_id = response.findtext('id') @@ -847,6 +858,7 @@ def test_calculate_progress_without_current_hosts(self): '', fs, ) + self.daemon.start_pending_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -900,6 +912,8 @@ def test_get_scan_progress_xml(self): '', fs, ) + self.daemon.start_pending_scans() + response = fs.get_response() scan_id = response.findtext('id') @@ -991,6 +1005,8 @@ def test_scan_exists(self, mock_create_process, _mock_os): self.daemon.handle_command( cmd, fs, ) + self.daemon.start_pending_scans() + response = fs.get_response() status = response.get('status_text') self.assertEqual(status, 'Continue') @@ -1008,7 +1024,7 @@ def test_result_order(self): '', fs, ) - + self.daemon.start_pending_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -1043,7 +1059,7 @@ def test_batch_result(self): '', fs, ) - + self.daemon.start_pending_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -1065,3 +1081,21 @@ def test_batch_result(self): for idx, res in enumerate(results): att_dict = res.attrib self.assertEqual(hosts[idx], att_dict['name']) + + def test_is_new_scan_allowed_false(self): + self.daemon.scan_processes = { # pylint: disable=protected-access + 'a': 1, + 'b': 2, + } + self.daemon.max_scans = 1 + + self.assertFalse(self.daemon.is_new_scan_allowed()) + + def test_is_new_scan_allowed_true(self): + self.daemon.scan_processes = { # pylint: disable=protected-access + 'a': 1, + 'b': 2, + } + self.daemon.max_scans = 3 + + self.assertTrue(self.daemon.is_new_scan_allowed()) diff --git a/tests/test_ssh_daemon.py b/tests/test_ssh_daemon.py index 363bd449..658263ad 100644 --- a/tests/test_ssh_daemon.py +++ b/tests/test_ssh_daemon.py @@ -75,6 +75,7 @@ class DummyWrapper(OSPDaemonSimpleSSH): def __init__(self, niceness=10): super().__init__(niceness=niceness) self.scan_collection.data_manager = FakeDataManager() + self.scan_collection.file_storage_dir = '/tmp' def check(self): return True @@ -107,7 +108,7 @@ def test_run_command(self): dict(port=5, ssh_timeout=15, username_password='dummy:pw'), '', ) - + daemon.start_pending_scans() res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd') self.assertIsInstance(res, list) @@ -130,7 +131,7 @@ def test_run_command_legacy_credential(self): dict(port=5, ssh_timeout=15, username='dummy', password='pw'), '', ) - + daemon.start_pending_scans() res = daemon.run_command(scanid, 
'host.example.com', 'cat /etc/passwd') self.assertIsInstance(res, list) @@ -164,6 +165,7 @@ def test_run_command_new_credential(self): dict(port=5, ssh_timeout=15), '', ) + daemon.start_pending_scans() res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd') self.assertIsInstance(res, list) @@ -186,6 +188,7 @@ def test_run_command_no_credential(self): dict(port=5, ssh_timeout=15), '', ) + daemon.start_pending_scans() with self.assertRaises(ValueError): daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd') From 1a045d4d249c1c9316b0bdc8c7e58541b7e4dc15 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Thu, 28 May 2020 16:05:45 +0200 Subject: [PATCH 15/31] Make pylint happy --- ospd/datapickler.py | 2 +- ospd/ospd.py | 2 ++ ospd/scan.py | 2 -- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ospd/datapickler.py b/ospd/datapickler.py index 733eda76..1dfc9735 100644 --- a/ospd/datapickler.py +++ b/ospd/datapickler.py @@ -86,7 +86,7 @@ def load_data(self, filename: str, original_data_hash: str) -> Dict: try: with storage_file_path.open('rb') as scan_info_f: pickled_data = scan_info_f.read() - except Exception as e: + except Exception as e: # pylint: disable=broad-except logger.error( 'Not possible to read pickled data from %s. %s', filename, e ) diff --git a/ospd/ospd.py b/ospd/ospd.py index 86952671..385e8a50 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -41,6 +41,8 @@ import defusedxml.ElementTree as secET +import psutil + from deprecated import deprecated from ospd import __version__ diff --git a/ospd/scan.py b/ospd/scan.py index a2654074..c73298b9 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -19,9 +19,7 @@ import multiprocessing import time import uuid -import pickle -from pathlib import Path from collections import OrderedDict from enum import Enum from typing import List, Any, Dict, Iterator, Optional, Iterable, Union From 829abf8feb82c763686785b55f7790a589e654ca Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Thu, 28 May 2020 17:04:51 +0200 Subject: [PATCH 16/31] Improve error handling when remove the pickled data file. --- ospd/datapickler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ospd/datapickler.py b/ospd/datapickler.py index 1dfc9735..fd698d06 100644 --- a/ospd/datapickler.py +++ b/ospd/datapickler.py @@ -37,7 +37,10 @@ def __init__(self, storage_path): def remove_file(self, filename): """ Remove the file containing a scan_info pickled object """ storage_file_path = Path(self._storage_path) / filename - storage_file_path.unlink() + try: + storage_file_path.unlink() + except Exception as e: # pylint: disable=broad-except + logger.error('Not possible to delete %s. %s', filename, e) def store_data(self, filename: str, data_object: Dict) -> str: """ Pickle a object and store it in a file named""" @@ -90,6 +93,7 @@ def load_data(self, filename: str, original_data_hash: str) -> Dict: logger.error( 'Not possible to read pickled data from %s. 
%s', filename, e ) + return unpickled_scan_info = None try: @@ -110,6 +114,6 @@ def _pickled_data_hash_generator(self, pickled_data): return hash_sha256 = sha256() - hash_sha256.update(pickle.dumps(pickled_data)) + hash_sha256.update(pickled_data) return hash_sha256.hexdigest() From db078a4091260288923750e58d6374a38223fe14 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Fri, 29 May 2020 09:40:23 +0200 Subject: [PATCH 17/31] Add test for data pickler --- tests/test_datapickler.py | 123 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 tests/test_datapickler.py diff --git a/tests/test_datapickler.py b/tests/test_datapickler.py new file mode 100644 index 00000000..c94603ed --- /dev/null +++ b/tests/test_datapickler.py @@ -0,0 +1,123 @@ +# Copyright (C) 2014-2020 Greenbone Networks GmbH +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import logging +import pickle + +from pathlib import Path +from hashlib import sha256 +from unittest import TestCase +from unittest.mock import Mock, patch + +from ospd.errors import OspdCommandError +from ospd.datapickler import DataPickler + +from .helper import assert_called + + +class DataPecklerTestCase(TestCase): + def test_store_data(self): + data = {'foo', 'bar'} + filename = 'scan_info_1' + pickled_data = pickle.dumps(data) + m = sha256() + m.update(pickled_data) + + data_pickler = DataPickler('/tmp') + ret = data_pickler.store_data(filename, data) + + self.assertEqual(ret, m.hexdigest()) + + data_pickler.remove_file(filename) + + def test_store_data_failed(self): + data = {'foo', 'bar'} + filename = 'scan_info_1' + pickled_data = pickle.dumps(data) + m = sha256() + m.update(pickled_data) + + data_pickler = DataPickler('/root') + + self.assertRaises( + OspdCommandError, data_pickler.store_data, filename, data + ) + + def test_load_data(self): + + data_pickler = DataPickler('/tmp') + + data = {'foo', 'bar'} + filename = 'scan_info_1' + pickled_data = pickle.dumps(data) + + m = sha256() + m.update(pickled_data) + pickled_data_hash = m.hexdigest() + + ret = data_pickler.store_data(filename, data) + self.assertEqual(ret, pickled_data_hash) + + original_data = data_pickler.load_data(filename, pickled_data_hash) + self.assertIsNotNone(original_data) + + self.assertIn('foo', original_data) + + @patch("ospd.datapickler.logger") + def test_remove_file_failed(self, mock_logger): + filename = 'inenxistent_file' + data_pickler = DataPickler('/root') + data_pickler.remove_file(filename) + + assert_called(mock_logger.error) + + @patch("ospd.datapickler.logger") + def test_load_data_no_file(self, mock_logger): + data = {'foo', 'bar'} + filename = 'scan_info_1' + data_pickler = DataPickler('/tmp') + + data_loaded = data_pickler.load_data(filename, "1234") + assert_called(mock_logger.error) + self.assertIsNone(data_loaded) + + data_pickler.remove_file(filename) + + def 
test_load_data_corrupted(self): + + data_pickler = DataPickler('/tmp') + + data = {'foo', 'bar'} + filename = 'scan_info_1' + pickled_data = pickle.dumps(data) + + m = sha256() + m.update(pickled_data) + pickled_data_hash = m.hexdigest() + + ret = data_pickler.store_data(filename, data) + self.assertEqual(ret, pickled_data_hash) + + # courrupt data + file_to_corrupt = Path(data_pickler._storage_path) / filename + with file_to_corrupt.open('ab') as f: + f.write(b'bar2') + + original_data = data_pickler.load_data(filename, pickled_data_hash) + self.assertIsNone(original_data) + + data_pickler.remove_file(filename) From caec16c4ddc108d464652dcac8e604bc21e87c11 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Fri, 29 May 2020 11:35:17 +0200 Subject: [PATCH 18/31] Create the file with permission only for the owner. Add test. --- ospd/datapickler.py | 23 ++++++++++++++++++++++- tests/test_datapickler.py | 15 +++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/ospd/datapickler.py b/ospd/datapickler.py index fd698d06..e36d7f15 100644 --- a/ospd/datapickler.py +++ b/ospd/datapickler.py @@ -20,6 +20,7 @@ import logging import pickle +import os from hashlib import sha256 from pathlib import Path @@ -29,10 +30,26 @@ logger = logging.getLogger(__name__) +OWNER_ONLY_RW_PERMISSION = 0o600 + class DataPickler: def __init__(self, storage_path): self._storage_path = storage_path + self._storage_fd = None + + def _fd_opener(self, path, flags): + os.umask(0) + flags = os.O_CREAT | os.O_WRONLY + self._storage_fd = os.open(path, flags, mode=OWNER_ONLY_RW_PERMISSION) + return self._storage_fd + + def _fd_close(self): + try: + self._storage_fd.close() + self._storage_fd = None + except Exception: # pylint: disable=broad-except + pass def remove_file(self, filename): """ Remove the file containing a scan_info pickled object """ @@ -65,13 +82,17 @@ def store_data(self, filename: str, data_object: Dict) -> str: ) try: - with storage_file_path.open('wb') as scan_info_f: + with open( + str(storage_file_path), 'wb', opener=self._fd_opener + ) as scan_info_f: scan_info_f.write(pickled_data) except Exception as e: # pylint: disable=broad-except + self._fd_close() raise OspdCommandError( 'Not possible to store scan info for %s. 
%s' % (filename, e), 'start_scan', ) + self._fd_close() return self._pickled_data_hash_generator(pickled_data) diff --git a/tests/test_datapickler.py b/tests/test_datapickler.py index c94603ed..c4818730 100644 --- a/tests/test_datapickler.py +++ b/tests/test_datapickler.py @@ -57,6 +57,21 @@ def test_store_data_failed(self): OspdCommandError, data_pickler.store_data, filename, data ) + def test_store_data_check_permission(self): + OWNER_ONLY_RW_PERMISSION = '0o100600' + data = {'foo', 'bar'} + filename = 'scan_info_1' + + data_pickler = DataPickler('/tmp') + data_pickler.store_data(filename, data) + + file_path = Path(data_pickler._storage_path) / filename + self.assertEqual( + oct(file_path.stat().st_mode), OWNER_ONLY_RW_PERMISSION + ) + + data_pickler.remove_file(filename) + def test_load_data(self): data_pickler = DataPickler('/tmp') From 35e46df1ff8b4f936fa9bd96abb3b14e8919b9e7 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Fri, 29 May 2020 15:40:41 +0200 Subject: [PATCH 19/31] More pylint fixes --- tests/test_datapickler.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/tests/test_datapickler.py b/tests/test_datapickler.py index c4818730..6dd3e2a1 100644 --- a/tests/test_datapickler.py +++ b/tests/test_datapickler.py @@ -15,13 +15,12 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import logging import pickle from pathlib import Path from hashlib import sha256 from unittest import TestCase -from unittest.mock import Mock, patch +from unittest.mock import patch from ospd.errors import OspdCommandError from ospd.datapickler import DataPickler @@ -34,22 +33,19 @@ def test_store_data(self): data = {'foo', 'bar'} filename = 'scan_info_1' pickled_data = pickle.dumps(data) - m = sha256() - m.update(pickled_data) + tmp_hash = sha256() + tmp_hash.update(pickled_data) data_pickler = DataPickler('/tmp') ret = data_pickler.store_data(filename, data) - self.assertEqual(ret, m.hexdigest()) + self.assertEqual(ret, tmp_hash.hexdigest()) data_pickler.remove_file(filename) def test_store_data_failed(self): data = {'foo', 'bar'} filename = 'scan_info_1' - pickled_data = pickle.dumps(data) - m = sha256() - m.update(pickled_data) data_pickler = DataPickler('/root') @@ -58,14 +54,17 @@ def test_store_data_failed(self): ) def test_store_data_check_permission(self): - OWNER_ONLY_RW_PERMISSION = '0o100600' + OWNER_ONLY_RW_PERMISSION = '0o100600' # pylint: disable=invalid-name data = {'foo', 'bar'} filename = 'scan_info_1' data_pickler = DataPickler('/tmp') data_pickler.store_data(filename, data) - file_path = Path(data_pickler._storage_path) / filename + file_path = ( + Path(data_pickler._storage_path) # pylint: disable=protected-access + / filename + ) self.assertEqual( oct(file_path.stat().st_mode), OWNER_ONLY_RW_PERMISSION ) @@ -80,9 +79,9 @@ def test_load_data(self): filename = 'scan_info_1' pickled_data = pickle.dumps(data) - m = sha256() - m.update(pickled_data) - pickled_data_hash = m.hexdigest() + tmp_hash = sha256() + tmp_hash.update(pickled_data) + pickled_data_hash = tmp_hash.hexdigest() ret = data_pickler.store_data(filename, data) self.assertEqual(ret, pickled_data_hash) @@ -102,7 +101,6 @@ def test_remove_file_failed(self, mock_logger): @patch("ospd.datapickler.logger") def test_load_data_no_file(self, mock_logger): - data = {'foo', 'bar'} filename = 'scan_info_1' data_pickler = DataPickler('/tmp') @@ -120,15 +118,18 @@ def test_load_data_corrupted(self): 
filename = 'scan_info_1' pickled_data = pickle.dumps(data) - m = sha256() - m.update(pickled_data) - pickled_data_hash = m.hexdigest() + tmp_hash = sha256() + tmp_hash.update(pickled_data) + pickled_data_hash = tmp_hash.hexdigest() ret = data_pickler.store_data(filename, data) self.assertEqual(ret, pickled_data_hash) # courrupt data - file_to_corrupt = Path(data_pickler._storage_path) / filename + file_to_corrupt = ( + Path(data_pickler._storage_path) # pylint: disable=protected-access + / filename + ) with file_to_corrupt.open('ab') as f: f.write(b'bar2') From a0495d252afbe8e94683907bcaa7e7c799f3433e Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Fri, 29 May 2020 16:33:25 +0200 Subject: [PATCH 20/31] Empty results element in get_scan response for pending scans --- ospd/ospd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ospd/ospd.py b/ospd/ospd.py index 385e8a50..1c147696 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -778,6 +778,7 @@ def get_scan_xml( response = Element('scan') detailed = False progress = False + response.append(Element('results')) else: target = self.get_scan_host(scan_id) scan_progress = self.get_scan_progress(scan_id) From 27dbfd71075992aa69986604d6e4d17db4aff235 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Tue, 2 Jun 2020 14:08:37 +0200 Subject: [PATCH 21/31] Rename Pending status to Queued --- ospd/ospd.py | 14 +++++++------- ospd/scan.py | 4 ++-- tests/command/test_commands.py | 24 ++++++++++++------------ tests/test_scan_and_result.py | 31 ++++++++++++++++--------------- tests/test_ssh_daemon.py | 8 ++++---- 5 files changed, 41 insertions(+), 40 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 1c147696..25f36798 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -371,7 +371,7 @@ def process_targets_element(scanner_target) -> List: def stop_scan(self, scan_id: str) -> None: if ( scan_id in self.scan_collection.ids_iterator() - and self.get_scan_status(scan_id) == ScanStatus.PENDING + and self.get_scan_status(scan_id) == ScanStatus.QUEUED ): self.scan_collection.remove_file_pickled_scan_info(scan_id) self.set_scan_status(scan_id, ScanStatus.STOPPED) @@ -769,7 +769,7 @@ def get_scan_xml( if not scan_id: return Element('scan') - if self.get_scan_status(scan_id) == ScanStatus.PENDING: + if self.get_scan_status(scan_id) == ScanStatus.QUEUED: target = '' scan_progress = 0 status = self.get_scan_status(scan_id) @@ -1199,13 +1199,13 @@ def run(self) -> None: time.sleep(SCHEDULER_CHECK_PERIOD) self.scheduler() self.clean_forgotten_scans() - self.start_pending_scans() + self.start_queued_scans() self.wait_for_children() except KeyboardInterrupt: logger.info("Received Ctrl-C shutting-down ...") - def start_pending_scans(self): - """ Starts a pending scan if it is allowed """ + def start_queued_scans(self): + """ Starts a queued scan if it is allowed """ for scan_id in self.scan_collection.ids_iterator(): if not self.is_new_scan_allowed(): @@ -1220,7 +1220,7 @@ def start_pending_scans(self): ) return - if self.get_scan_status(scan_id) == ScanStatus.PENDING: + if self.get_scan_status(scan_id) == ScanStatus.QUEUED: try: self.scan_collection.unpickle_scan_info(scan_id) except OspdCommandError as e: @@ -1336,7 +1336,7 @@ def clean_forgotten_scans(self) -> None: def check_scan_process(self, scan_id: str) -> None: """ Check the scan's process, and terminate the scan if not alive. 
""" - if self.get_scan_status(scan_id) == ScanStatus.PENDING: + if self.get_scan_status(scan_id) == ScanStatus.QUEUED: return scan_process = self.scan_processes.get(scan_id) diff --git a/ospd/scan.py b/ospd/scan.py index c73298b9..e8aae764 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -34,7 +34,7 @@ class ScanStatus(Enum): """Scan status. """ - PENDING = 0 + QUEUED = 0 INIT = 1 RUNNING = 2 STOPPED = 3 @@ -254,7 +254,7 @@ def create_scan( credentials = target.pop('credentials') scan_info = self.data_manager.dict() # type: Dict - scan_info['status'] = ScanStatus.PENDING + scan_info['status'] = ScanStatus.QUEUED scan_info['credentials'] = credentials scan_info['start_time'] = int(time.time()) diff --git a/tests/command/test_commands.py b/tests/command/test_commands.py index 1c9b1e11..b0f1ac32 100644 --- a/tests/command/test_commands.py +++ b/tests/command/test_commands.py @@ -119,14 +119,14 @@ def test_scan_with_vts(self, mock_create_process): # With one vt, without params response = et.fromstring(cmd.handle_xml(request)) - daemon.start_pending_scans() + daemon.start_queued_scans() scan_id = response.findtext('id') vts_collection = daemon.get_scan_vts(scan_id) self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []}) self.assertNotEqual(vts_collection, {'1.2.3.6': {}}) - daemon.start_pending_scans() + daemon.start_queued_scans() assert_called(mock_create_process) def test_scan_pop_vts(self): @@ -151,7 +151,7 @@ def test_scan_pop_vts(self): # With one vt, without params response = et.fromstring(cmd.handle_xml(request)) scan_id = response.findtext('id') - daemon.start_pending_scans() + daemon.start_queued_scans() vts_collection = daemon.get_scan_vts(scan_id) self.assertEqual(vts_collection, {'1.2.3.4': {}, 'vt_groups': []}) self.assertRaises(KeyError, daemon.get_scan_vts, scan_id) @@ -177,7 +177,7 @@ def test_scan_pop_ports(self): # With one vt, without params response = et.fromstring(cmd.handle_xml(request)) - daemon.start_pending_scans() + daemon.start_queued_scans() scan_id = response.findtext('id') ports = daemon.scan_collection.get_ports(scan_id) @@ -203,7 +203,7 @@ def test_scan_without_vts(self, mock_create_process): ) response = et.fromstring(cmd.handle_xml(request)) - daemon.start_pending_scans() + daemon.start_queued_scans() scan_id = response.findtext('id') self.assertEqual(daemon.get_scan_vts(scan_id), {}) @@ -256,7 +256,7 @@ def test_scan_with_vts_and_param(self, mock_create_process): '' ) response = et.fromstring(cmd.handle_xml(request)) - daemon.start_pending_scans() + daemon.start_queued_scans() scan_id = response.findtext('id') @@ -264,7 +264,7 @@ def test_scan_with_vts_and_param(self, mock_create_process): daemon.get_scan_vts(scan_id), {'1234': {'ABC': '200'}, 'vt_groups': []}, ) - daemon.start_pending_scans() + daemon.start_queued_scans() assert_called(mock_create_process) def test_scan_with_vts_and_param_missing_vt_group_filter(self): @@ -284,7 +284,7 @@ def test_scan_with_vts_and_param_missing_vt_group_filter(self): '' '' ) - daemon.start_pending_scans() + daemon.start_queued_scans() with self.assertRaises(OspdError): cmd.handle_xml(request) @@ -312,7 +312,7 @@ def test_scan_with_vts_and_param_with_vt_group_filter( '' ) response = et.fromstring(cmd.handle_xml(request)) - daemon.start_pending_scans() + daemon.start_queued_scans() scan_id = response.findtext('id') self.assertEqual(daemon.get_scan_vts(scan_id), {'vt_groups': ['a']}) @@ -337,7 +337,7 @@ def test_scan_ignore_multi_target(self, mock_logger, mock_create_process): ) cmd.handle_xml(request) - 
daemon.start_pending_scans() + daemon.start_queued_scans() assert_called(mock_logger.warning) assert_called(mock_create_process) @@ -357,7 +357,7 @@ def test_scan_use_legacy_target_and_port( ) response = et.fromstring(cmd.handle_xml(request)) - daemon.start_pending_scans() + daemon.start_queued_scans() scan_id = response.findtext('id') self.assertIsNotNone(scan_id) @@ -393,7 +393,7 @@ def test_stop_scan(self, mock_create_process, mock_os): daemon.handle_command(request, fs) response = fs.get_response() - daemon.start_pending_scans() + daemon.start_queued_scans() assert_called(mock_create_process) assert_called(mock_process.start) diff --git a/tests/test_scan_and_result.py b/tests/test_scan_and_result.py index 16464bde..8eb13716 100644 --- a/tests/test_scan_and_result.py +++ b/tests/test_scan_and_result.py @@ -508,7 +508,7 @@ def test_clean_forgotten_scans(self): finished = False - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() while not finished: fs = FakeStream() self.daemon.handle_command( @@ -565,7 +565,7 @@ def test_scan_with_error(self): response = fs.get_response() scan_id = response.findtext('id') finished = False - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() self.daemon.add_scan_error( scan_id, host='a', value='something went wrong' ) @@ -613,7 +613,7 @@ def test_get_scan_pop(self): '', fs, ) - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -656,7 +656,7 @@ def test_get_scan_pop_max_res(self): '', fs, ) - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -726,7 +726,7 @@ def test_target_with_credentials(self): '', fs, ) - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() response = fs.get_response() self.assertEqual(response.get('status'), '200') @@ -757,7 +757,7 @@ def test_scan_get_target(self): '', fs, ) - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -782,7 +782,7 @@ def test_scan_get_target_options(self): '', fs, ) - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() response = fs.get_response() @@ -804,7 +804,7 @@ def test_progress(self): '', fs, ) - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -828,7 +828,7 @@ def test_sort_host_finished(self): '', fs, ) - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() response = fs.get_response() @@ -858,7 +858,7 @@ def test_calculate_progress_without_current_hosts(self): '', fs, ) - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -890,6 +890,7 @@ def test_get_scan_without_scanid(self): '', fs, ) + self.daemon.start_queued_scans() fs = FakeStream() self.assertRaises( @@ -912,7 +913,7 @@ def test_get_scan_progress_xml(self): '', fs, ) - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() response = fs.get_response() scan_id = response.findtext('id') @@ -984,7 +985,7 @@ def test_scan_exists(self, mock_create_process, _mock_os): status = response.get('status_text') self.assertEqual(status, 'OK') - self.daemon.start_pending_scans() + self.daemon.start_queued_scans() assert_called(mock_create_process) assert_called(mock_process.start) @@ -1005,7 +1006,7 @@ def test_scan_exists(self, 
mock_create_process, _mock_os):
         self.daemon.handle_command(
             cmd, fs,
         )
-        self.daemon.start_pending_scans()
+        self.daemon.start_queued_scans()
         response = fs.get_response()
 
         status = response.get('status_text')
@@ -1024,7 +1025,7 @@ def test_result_order(self):
             '',
             fs,
         )
-        self.daemon.start_pending_scans()
+        self.daemon.start_queued_scans()
         response = fs.get_response()
 
         scan_id = response.findtext('id')
@@ -1059,7 +1060,7 @@ def test_batch_result(self):
             '',
             fs,
         )
-        self.daemon.start_pending_scans()
+        self.daemon.start_queued_scans()
         response = fs.get_response()
 
         scan_id = response.findtext('id')
diff --git a/tests/test_ssh_daemon.py b/tests/test_ssh_daemon.py
index 658263ad..6e0b5cba 100644
--- a/tests/test_ssh_daemon.py
+++ b/tests/test_ssh_daemon.py
@@ -108,7 +108,7 @@ def test_run_command(self):
             dict(port=5, ssh_timeout=15, username_password='dummy:pw'),
             '',
         )
-        daemon.start_pending_scans()
+        daemon.start_queued_scans()
         res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd')
 
         self.assertIsInstance(res, list)
@@ -131,7 +131,7 @@ def test_run_command_legacy_credential(self):
             dict(port=5, ssh_timeout=15, username='dummy', password='pw'),
             '',
         )
-        daemon.start_pending_scans()
+        daemon.start_queued_scans()
         res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd')
 
         self.assertIsInstance(res, list)
@@ -165,7 +165,7 @@ def test_run_command_new_credential(self):
             dict(port=5, ssh_timeout=15),
             '',
         )
-        daemon.start_pending_scans()
+        daemon.start_queued_scans()
         res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd')
 
         self.assertIsInstance(res, list)
@@ -188,7 +188,7 @@ def test_run_command_no_credential(self):
             dict(port=5, ssh_timeout=15),
             '',
         )
-        daemon.start_pending_scans()
+        daemon.start_queued_scans()
         with self.assertRaises(ValueError):
             daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd')

From 5e155a8351e70c35bd26278fb3c97afd63ce00a7 Mon Sep 17 00:00:00 2001
From: Juan Jose Nicola
Date: Tue, 2 Jun 2020 15:44:42 +0200
Subject: [PATCH 22/31] Check if there is enough free memory before starting a scan.

If there is not enough memory, the scan is kept in the queue.
A minimum free memory value in MB must be given as an option, otherwise
no free memory check will be performed.
---
 ospd/ospd.py                  | 15 +++++++--------
 ospd/parser.py                | 19 ++++++-------------
 tests/helper.py               |  5 +++++
 tests/test_scan_and_result.py | 24 +++++++++++++++++++++++-
 4 files changed, 41 insertions(+), 22 deletions(-)

diff --git a/ospd/ospd.py b/ospd/ospd.py
index 25f36798..e53ec778 100644
--- a/ospd/ospd.py
+++ b/ospd/ospd.py
@@ -118,7 +118,7 @@ def __init__(
         customvtfilter=None,
         storage=None,
         max_scans=0,
-        check_free_memory=False,
+        min_free_mem_scan_queue=0,
         file_storage_dir='/var/run/ospd',
         **kwargs
     ):  # pylint: disable=unused-argument
@@ -141,7 +141,7 @@ def __init__(
         self.initialized = None  # Set after initialization finished
 
         self.max_scans = max_scans
-        self.check_free_memory = check_free_memory
+        self.min_free_mem_scan_queue = min_free_mem_scan_queue
 
         self.scaninfo_store_time = kwargs.get('scaninfo_store_time')
 
@@ -1214,7 +1214,10 @@ def start_queued_scans(self):
             )
             return
 
-        if self.check_free_memory and not self.is_enough_free_memory():
+        if (
+            self.min_free_mem_scan_queue
+            and not self.is_enough_free_memory()
+        ):
             logger.debug(
                 'Not possible to run a new scan. Not enough free memory.'
             )
             return
@@ -1253,13 +1256,9 @@ def is_enough_free_memory(self) -> bool:
         Return:
             True if there is enough memory for a new scan.
""" - - ps_process = psutil.Process() - proc_memory = ps_process.memory_info().rss - free_mem = psutil.virtual_memory().free - if free_mem > (4 * proc_memory): + if (free_mem / (1024 * 1024)) > self.min_free_mem_scan_queue: return True return False diff --git a/ospd/parser.py b/ospd/parser.py index 8560c391..be0f7f0a 100644 --- a/ospd/parser.py +++ b/ospd/parser.py @@ -38,6 +38,7 @@ DEFAULT_STREAM_TIMEOUT = 10 # ten seconds DEFAULT_SCANINFO_STORE_TIME = 0 # in hours DEFAULT_MAX_SCAN = 0 # 0 = disable +DEFAULT_MIN_FREE_MEM_SCAN_QUEUE = 0 # 0 = Disable ParserType = argparse.ArgumentParser Arguments = argparse.Namespace @@ -167,11 +168,11 @@ def __init__(self, description: str) -> None: 'Default %(default)s, disabled', ) parser.add_argument( - '--check-free-memory', - default=False, - type=self.str2bool, - help='Check if there is enough free memory to run the scan. ' - 'This is an experimental feature. ' + '--min-free-mem-scan-queue', + default=DEFAULT_MIN_FREE_MEM_SCAN_QUEUE, + type=int, + help='Minimum free memory in MB required to run the scan. ' + 'If no enough free memory is available, the scan queued. ' 'Default %(default)s, disabled', ) @@ -187,14 +188,6 @@ def network_port(self, string: str) -> int: ) return value - def str2bool(self, value: Union[int, str, bool]) -> bool: - """ Check if provided string is a valid bool value. """ - if isinstance(value, bool): - return value - if value.lower() in ('yes', 'true', 't', 'y', '1'): - return True - return False - def log_level(self, string: str) -> int: """ Check if provided string is a valid log level. """ diff --git a/tests/helper.py b/tests/helper.py index e718b7a7..29a255ea 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -38,6 +38,11 @@ def assert_called(mock: Mock): raise AssertionError(msg) +class FakePsutil: + def __init__(self, free=None): + self.free = free + + class FakeStream: def __init__(self): self.response = b'' diff --git a/tests/test_scan_and_result.py b/tests/test_scan_and_result.py index 8eb13716..c62ff63d 100644 --- a/tests/test_scan_and_result.py +++ b/tests/test_scan_and_result.py @@ -32,7 +32,13 @@ from ospd.resultlist import ResultList from ospd.errors import OspdCommandError -from .helper import DummyWrapper, assert_called, FakeStream, FakeDataManager +from .helper import ( + DummyWrapper, + assert_called, + FakeStream, + FakeDataManager, + FakePsutil, +) class FakeStartProcess: @@ -1100,3 +1106,19 @@ def test_is_new_scan_allowed_true(self): self.daemon.max_scans = 3 self.assertTrue(self.daemon.is_new_scan_allowed()) + + @patch("ospd.ospd.psutil") + def test_free_memory_true(self, mock_psutil): + self.daemon.min_free_mem_scan_queue = 1000 + # 1.5 GB free + mock_psutil.virtual_memory.return_value = FakePsutil(free=1500000000) + + self.assertTrue(self.daemon.is_enough_free_memory()) + + @patch("ospd.ospd.psutil") + def test_free_memory_false(self, mock_psutil): + self.daemon.min_free_mem_scan_queue = 2000 + # 1.5 GB free + mock_psutil.virtual_memory.return_value = FakePsutil(free=1500000000) + + self.assertFalse(self.daemon.is_enough_free_memory()) From c16d416e10dad4a7509650b3505eb706f60a09aa Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Tue, 2 Jun 2020 16:40:33 +0200 Subject: [PATCH 23/31] Cleanup queued scans if the daemon is killed. This will remove the pickled scan info files. 
--- ospd/main.py | 17 ++++++++++++++--- ospd/scan.py | 6 ++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/ospd/main.py b/ospd/main.py index 03554da9..f097baf3 100644 --- a/ospd/main.py +++ b/ospd/main.py @@ -34,6 +34,7 @@ from ospd.parser import create_parser, ParserType from ospd.server import TlsServer, UnixSocketServer, BaseServer + COPYRIGHT = """Copyright (C) 2014, 2015, 2018, 2019 Greenbone Networks GmbH License GPLv2+: GNU GPL version 2 or later This is free software: you are free to change and redistribute it. @@ -107,12 +108,18 @@ def init_logging( def exit_cleanup( - pidfile: str, server: BaseServer, _signum=None, _frame=None + pidfile: str, + server: BaseServer, + daemon: OSPDaemon, + _signum=None, + _frame=None, ) -> None: """ Removes the pidfile before ending the daemon. """ signal.signal(signal.SIGINT, signal.SIG_IGN) pidpath = Path(pidfile) + daemon.scan_collection.clean_up_pickled_scan_info() + if not pidpath.is_file(): return @@ -174,8 +181,12 @@ def main( sys.exit() # Set signal handler and cleanup - atexit.register(exit_cleanup, pidfile=args.pid_file, server=server) - signal.signal(signal.SIGTERM, partial(exit_cleanup, args.pid_file, server)) + atexit.register( + exit_cleanup, pidfile=args.pid_file, server=server, daemon=daemon + ) + signal.signal( + signal.SIGTERM, partial(exit_cleanup, args.pid_file, server, daemon) + ) if not daemon.check(): return 1 diff --git a/ospd/scan.py b/ospd/scan.py index e8aae764..ccb81e7b 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -204,6 +204,12 @@ def ids_iterator(self) -> Iterator[str]: return iter(self.scans_table.keys()) + def clean_up_pickled_scan_info(self): + """ Remove files of pickled scan info """ + for scan_id in self.ids_iterator(): + if self.get_status(scan_id) == ScanStatus.QUEUED: + self.remove_file_pickled_scan_info(scan_id) + def remove_file_pickled_scan_info(self, scan_id): pickler = DataPickler(self.file_storage_dir) pickler.remove_file(scan_id) From 263e69fd796ee05e73673718f0287f67654e04db Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Tue, 2 Jun 2020 16:54:05 +0200 Subject: [PATCH 24/31] Make pylint happy again --- ospd/parser.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ospd/parser.py b/ospd/parser.py index be0f7f0a..f375f383 100644 --- a/ospd/parser.py +++ b/ospd/parser.py @@ -18,7 +18,6 @@ import argparse import logging from pathlib import Path -from typing import Union from ospd.config import Config From 3c9f2259409fa62bda2b1c7b98b33ee07df9fbf1 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 3 Jun 2020 13:23:50 +0200 Subject: [PATCH 25/31] Add option and check for max count of queued scans. If the max is reached, the next scans are rejected. --- ospd/command/command.py | 9 +++++++++ ospd/ospd.py | 10 ++++++++++ ospd/parser.py | 9 +++++++++ 3 files changed, 28 insertions(+) diff --git a/ospd/command/command.py b/ospd/command/command.py index 7903d804..9e01b854 100644 --- a/ospd/command/command.py +++ b/ospd/command/command.py @@ -478,6 +478,15 @@ def handle_xml(self, xml: Element) -> bytes: Response string for command. 
""" + if ( + self._daemon.max_queued_scans + and self._daemon.get_count_queued_scans() + >= self._daemon.max_queued_scans + ): + raise OspdCommandError( + 'Maximum number of queued scans reached.', 'start_scan' + ) + target_str = xml.get('target') ports_str = xml.get('ports') diff --git a/ospd/ospd.py b/ospd/ospd.py index e53ec778..0899ea73 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -120,6 +120,7 @@ def __init__( max_scans=0, min_free_mem_scan_queue=0, file_storage_dir='/var/run/ospd', + max_queued_scans=0, **kwargs ): # pylint: disable=unused-argument """ Initializes the daemon's internal data. """ @@ -142,6 +143,7 @@ def __init__( self.max_scans = max_scans self.min_free_mem_scan_queue = min_free_mem_scan_queue + self.max_queued_scans = max_queued_scans self.scaninfo_store_time = kwargs.get('scaninfo_store_time') @@ -1357,6 +1359,14 @@ def check_scan_process(self, scan_id: str) -> None: elif progress == PROGRESS_FINISHED: scan_process.join(0) + def get_count_queued_scans(self) -> int: + """ Get the amount of scans with queued status """ + count = 0 + for scan_id in self.scan_collection.ids_iterator(): + if self.get_scan_status(scan_id) == ScanStatus.QUEUED: + count += 1 + return count + def get_scan_progress(self, scan_id: str) -> int: """ Gives a scan's current progress value. """ return self.scan_collection.get_progress(scan_id) diff --git a/ospd/parser.py b/ospd/parser.py index f375f383..88150c1c 100644 --- a/ospd/parser.py +++ b/ospd/parser.py @@ -38,6 +38,7 @@ DEFAULT_SCANINFO_STORE_TIME = 0 # in hours DEFAULT_MAX_SCAN = 0 # 0 = disable DEFAULT_MIN_FREE_MEM_SCAN_QUEUE = 0 # 0 = Disable +DEFAULT_MAX_QUEUED_SCANS = 0 # 0 = Disable ParserType = argparse.ArgumentParser Arguments = argparse.Namespace @@ -174,6 +175,14 @@ def __init__(self, description: str) -> None: 'If no enough free memory is available, the scan queued. ' 'Default %(default)s, disabled', ) + parser.add_argument( + '--max-queued-scans', + default=DEFAULT_MAX_QUEUED_SCANS, + type=int, + help='Maximum number allowed of queued scans before ' + 'starting to reject new scans. 
' + 'Default %(default)s, disabled', + ) self.parser = parser From e79c8282d5bc7e68d5be31504ee8782e8d87e5c7 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Wed, 3 Jun 2020 13:25:08 +0200 Subject: [PATCH 26/31] Add test for max queued scans --- tests/command/test_commands.py | 25 +++++++++++++++++++++++++ tests/test_scan_and_result.py | 18 ++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/tests/command/test_commands.py b/tests/command/test_commands.py index b0f1ac32..b7f12516 100644 --- a/tests/command/test_commands.py +++ b/tests/command/test_commands.py @@ -341,6 +341,31 @@ def test_scan_ignore_multi_target(self, mock_logger, mock_create_process): assert_called(mock_logger.warning) assert_called(mock_create_process) + def test_max_queued_scans_reached(self): + daemon = DummyWrapper([]) + daemon.max_queued_scans = 1 + cmd = StartScan(daemon) + request = et.fromstring( + '' + '' + '' + 'localhosts' + '22' + '' + '' + '' + '' + ) + + # create first scan + response = et.fromstring(cmd.handle_xml(request)) + scan_id_1 = response.findtext('id') + + with self.assertRaises(OspdCommandError): + cmd.handle_xml(request) + + daemon.scan_collection.remove_file_pickled_scan_info(scan_id_1) + @patch("ospd.ospd.create_process") @patch("ospd.command.command.logger") def test_scan_use_legacy_target_and_port( diff --git a/tests/test_scan_and_result.py b/tests/test_scan_and_result.py index c62ff63d..9c99bfa8 100644 --- a/tests/test_scan_and_result.py +++ b/tests/test_scan_and_result.py @@ -1122,3 +1122,21 @@ def test_free_memory_false(self, mock_psutil): mock_psutil.virtual_memory.return_value = FakePsutil(free=1500000000) self.assertFalse(self.daemon.is_enough_free_memory()) + + def test_count_queued_scans(self): + fs = FakeStream() + self.daemon.handle_command( + '' + '' + '' + '' + 'localhosts,192.168.0.0/24' + '80,443' + '' + '', + fs, + ) + + self.assertEqual(self.daemon.get_count_queued_scans(), 1) + self.daemon.start_queued_scans() + self.assertEqual(self.daemon.get_count_queued_scans(), 0) From eb22217200703fe58acf1a8243c2b507721350fa Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Thu, 4 Jun 2020 11:06:12 +0200 Subject: [PATCH 27/31] Add typing to datapickler --- ospd/datapickler.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ospd/datapickler.py b/ospd/datapickler.py index e36d7f15..67539897 100644 --- a/ospd/datapickler.py +++ b/ospd/datapickler.py @@ -24,7 +24,7 @@ from hashlib import sha256 from pathlib import Path -from typing import Dict +from typing import Dict, BinaryIO, Any from ospd.errors import OspdCommandError @@ -34,24 +34,24 @@ class DataPickler: - def __init__(self, storage_path): + def __init__(self, storage_path: str): self._storage_path = storage_path self._storage_fd = None - def _fd_opener(self, path, flags): + def _fd_opener(self, path: str, flags: int) -> BinaryIO: os.umask(0) flags = os.O_CREAT | os.O_WRONLY self._storage_fd = os.open(path, flags, mode=OWNER_ONLY_RW_PERMISSION) return self._storage_fd - def _fd_close(self): + def _fd_close(self) -> None: try: self._storage_fd.close() self._storage_fd = None except Exception: # pylint: disable=broad-except pass - def remove_file(self, filename): + def remove_file(self, filename: str) -> None: """ Remove the file containing a scan_info pickled object """ storage_file_path = Path(self._storage_path) / filename try: @@ -59,7 +59,7 @@ def remove_file(self, filename): except Exception as e: # pylint: disable=broad-except logger.error('Not possible to delete %s. 
%s', filename, e) - def store_data(self, filename: str, data_object: Dict) -> str: + def store_data(self, filename: str, data_object: Any) -> str: """ Pickle a object and store it in a file named""" storage_file_path = Path(self._storage_path) / filename @@ -96,7 +96,7 @@ def store_data(self, filename: str, data_object: Dict) -> str: return self._pickled_data_hash_generator(pickled_data) - def load_data(self, filename: str, original_data_hash: str) -> Dict: + def load_data(self, filename: str, original_data_hash: str) -> Any: """ Unpickle the stored data in the filename. Perform an intengrity check of the read data with the the hash generated with the original data. @@ -129,7 +129,7 @@ def load_data(self, filename: str, original_data_hash: str) -> Dict: if original_data_hash == pickled_scan_info_hash: return unpickled_scan_info - def _pickled_data_hash_generator(self, pickled_data): + def _pickled_data_hash_generator(self, pickled_data: bytes) -> str: """ Calculate the sha256 hash of a pickled data """ if not pickled_data: return From a9c453012ea5b862053afacfe73f7c8fe1af1a31 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Thu, 4 Jun 2020 11:36:03 +0200 Subject: [PATCH 28/31] Don't access scan collection method from main.py Instead use a method wrapper to perform the file clean up. --- ospd/main.py | 2 +- ospd/ospd.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ospd/main.py b/ospd/main.py index f097baf3..29966599 100644 --- a/ospd/main.py +++ b/ospd/main.py @@ -118,7 +118,7 @@ def exit_cleanup( signal.signal(signal.SIGINT, signal.SIG_IGN) pidpath = Path(pidfile) - daemon.scan_collection.clean_up_pickled_scan_info() + daemon.daemon_exit_cleanup() if not pidpath.is_file(): return diff --git a/ospd/ospd.py b/ospd/ospd.py index 0899ea73..4115ab00 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -432,6 +432,10 @@ def finish_scan(self, scan_id: str) -> None: self.set_scan_status(scan_id, ScanStatus.FINISHED) logger.info("%s: Scan finished.", scan_id) + def daemon_exit_cleanup(self): + """ Perform a cleanup before exiting """ + self.scan_collection.clean_up_pickled_scan_info() + def get_daemon_name(self) -> str: """ Gives osp daemon's name. """ return self.daemon_info['name'] From 3a35e06645b91b0d582b82ca972e7023fd0477a3 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Thu, 4 Jun 2020 12:05:19 +0200 Subject: [PATCH 29/31] Return immediately if unpickling fails. Also, remove the file before raising an error when the unpickling fails. --- ospd/datapickler.py | 8 ++++++-- ospd/scan.py | 3 ++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ospd/datapickler.py b/ospd/datapickler.py index 67539897..3a5fe64c 100644 --- a/ospd/datapickler.py +++ b/ospd/datapickler.py @@ -123,11 +123,15 @@ def load_data(self, filename: str, original_data_hash: str) -> Any: logger.error( 'Not possible to read pickled data from %s. 
%s', filename, e ) + return pickled_scan_info_hash = self._pickled_data_hash_generator(pickled_data) - if original_data_hash == pickled_scan_info_hash: - return unpickled_scan_info + if original_data_hash != pickled_scan_info_hash: + logger.error('Unpickled data from %s corrupted.', filename) + return + + return unpickled_scan_info def _pickled_data_hash_generator(self, pickled_data: bytes) -> str: """ Calculate the sha256 hash of a pickled data """ diff --git a/ospd/scan.py b/ospd/scan.py index ccb81e7b..7eb7e99d 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -58,7 +58,7 @@ class ScanCollection: """ - def __init__(self, file_storage_dir) -> None: + def __init__(self, file_storage_dir: str) -> None: """ Initialize the Scan Collection. """ self.data_manager = ( @@ -225,6 +225,7 @@ def unpickle_scan_info(self, scan_id): unpickled_scan_info = pickler.load_data(scan_id, scan_info_hash) if not unpickled_scan_info: + pickler.remove_file(scan_id) raise OspdCommandError( 'Not possible to unpickle stored scan info for %s' % scan_id, 'start_scan', From c09489bbda8884bd1496d0c4ed1119fef641f8b8 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Thu, 4 Jun 2020 12:13:36 +0200 Subject: [PATCH 30/31] Check for min_free_mem_scan_queue inside is_enough_free_memory() --- ospd/ospd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 4115ab00..822d570b 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1220,10 +1220,7 @@ def start_queued_scans(self): ) return - if ( - self.min_free_mem_scan_queue - and not self.is_enough_free_memory() - ): + if not self.is_enough_free_memory(): logger.debug( 'Not possible to run a new scan. Not enough free memory.' ) @@ -1262,6 +1259,9 @@ def is_enough_free_memory(self) -> bool: Return: True if there is enough memory for a new scan. 
""" + if not self.min_free_mem_scan_queue: + return True + free_mem = psutil.virtual_memory().free if (free_mem / (1024 * 1024)) > self.min_free_mem_scan_queue: From 4081913adc5d4ce7d60e015ffb0a5e146d36637a Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Thu, 4 Jun 2020 14:38:45 +0200 Subject: [PATCH 31/31] More typing --- ospd/ospd.py | 4 ++-- ospd/scan.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 822d570b..610a08b7 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -432,7 +432,7 @@ def finish_scan(self, scan_id: str) -> None: self.set_scan_status(scan_id, ScanStatus.FINISHED) logger.info("%s: Scan finished.", scan_id) - def daemon_exit_cleanup(self): + def daemon_exit_cleanup(self) -> None: """ Perform a cleanup before exiting """ self.scan_collection.clean_up_pickled_scan_info() @@ -1210,7 +1210,7 @@ def run(self) -> None: except KeyboardInterrupt: logger.info("Received Ctrl-C shutting-down ...") - def start_queued_scans(self): + def start_queued_scans(self) -> None: """ Starts a queued scan if it is allowed """ for scan_id in self.scan_collection.ids_iterator(): diff --git a/ospd/scan.py b/ospd/scan.py index 7eb7e99d..07289395 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -204,17 +204,17 @@ def ids_iterator(self) -> Iterator[str]: return iter(self.scans_table.keys()) - def clean_up_pickled_scan_info(self): + def clean_up_pickled_scan_info(self) -> None: """ Remove files of pickled scan info """ for scan_id in self.ids_iterator(): if self.get_status(scan_id) == ScanStatus.QUEUED: self.remove_file_pickled_scan_info(scan_id) - def remove_file_pickled_scan_info(self, scan_id): + def remove_file_pickled_scan_info(self, scan_id: str) -> None: pickler = DataPickler(self.file_storage_dir) pickler.remove_file(scan_id) - def unpickle_scan_info(self, scan_id): + def unpickle_scan_info(self, scan_id: str) -> None: """ Unpickle a stored scan_inf correspinding to the scan_id and store it in the scan_table """