diff --git a/ospd/misc.py b/ospd/misc.py index 9378e82f..a746d9f9 100644 --- a/ospd/misc.py +++ b/ospd/misc.py @@ -40,6 +40,7 @@ import uuid import multiprocessing import itertools +from enum import Enum LOGGER = logging.getLogger(__name__) @@ -51,6 +52,14 @@ PORT = 1234 ADDRESS = "0.0.0.0" +class ScanStatus(Enum): + """Scan status. """ + INIT = 0 + RUNNING = 1 + STOPPED = 2 + FINISHED = 3 + + class ScanCollection(object): @@ -103,15 +112,14 @@ def set_progress(self, scan_id, progress): if progress == 100: self.scans_table[scan_id]['end_time'] = int(time.time()) - def set_target_progress(self, scan_id, target, progress): + def set_target_progress(self, scan_id, target, host, progress): """ Sets scan_id scan's progress. """ if progress > 0 and progress <= 100: - target_process = dict() - target_process = self.scans_table[scan_id]['target_progress'] - target_process[target] = progress + targets = self.scans_table[scan_id]['target_progress'] + targets[target][host] = progress # Set scan_info's target_progress to propagate progresses # to parent process. - self.scans_table[scan_id]['target_progress'] = target_process + self.scans_table[scan_id]['target_progress'] = targets def set_host_finished(self, scan_id, target, host): """ Add the host in a list of finished hosts """ @@ -135,26 +143,26 @@ def ids_iterator(self): return iter(self.scans_table.keys()) - def create_scan(self, scan_id='', targets='', target_str=None, - options=dict(), vts=''): + def create_scan(self, scan_id='', targets='', options=None, vts=''): """ Creates a new scan with provided scan information. """ if self.data_manager is None: self.data_manager = multiprocessing.Manager() + if not options: + options = dict() scan_info = self.data_manager.dict() scan_info['results'] = list() scan_info['finished_hosts'] = dict( [[target, []] for target, _, _ in targets]) scan_info['progress'] = 0 scan_info['target_progress'] = dict( - [[target, 0] for target, _, _ in targets]) + [[target, {}] for target, _, _ in targets]) scan_info['targets'] = targets - scan_info['legacy_target'] = target_str scan_info['vts'] = vts scan_info['options'] = options scan_info['start_time'] = int(time.time()) scan_info['end_time'] = "0" - scan_info['status'] = "" + scan_info['status'] = ScanStatus.INIT if scan_id is None or scan_id == '': scan_id = str(uuid.uuid4()) scan_info['scan_id'] = scan_id @@ -185,10 +193,19 @@ def get_progress(self, scan_id): return self.scans_table[scan_id]['progress'] - def get_target_progress(self, scan_id): - """ Get a scan's current progress value. """ + def get_target_progress(self, scan_id, target): + """ Get a target's current progress value. + The value is calculated with the progress of each single host + in the target.""" - return self.scans_table[scan_id]['target_progress'] + total_hosts = len(target_str_to_list(target)) + host_progresses = self.scans_table[scan_id]['target_progress'].get(target) + try: + t_prog = sum(host_progresses.values()) / total_hosts + except ZeroDivisionError: + LOGGER.error("Zero division error in ", get_target_progress.__name__) + raise + return t_prog def get_start_time(self, scan_id): """ Get a scan's start time. """ @@ -200,16 +217,13 @@ def get_end_time(self, scan_id): return self.scans_table[scan_id]['end_time'] - def get_target(self, scan_id): + def get_target_list(self, scan_id): """ Get a scan's target list. """ - if self.scans_table[scan_id]['legacy_target']: - return self.scans_table[scan_id]['legacy_target'] target_list = [] - for item in self.scans_table[scan_id]['targets']: - target_list.append(item[0]) - separ = ',' - return separ.join(target_list) + for target, _, _ in self.scans_table[scan_id]['targets']: + target_list.append(target) + return target_list def get_ports(self, scan_id, target): """ Get a scan's ports list. If a target is specified @@ -246,7 +260,7 @@ def id_exists(self, scan_id): def delete_scan(self, scan_id): """ Delete a scan if fully finished. """ - if self.get_status(scan_id) == "running": + if self.get_status(scan_id) == ScanStatus.RUNNING: return False self.scans_table.pop(scan_id) if len(self.scans_table) == 0: diff --git a/ospd/ospd.py b/ospd/ospd.py index 0b5d459e..b904e1c4 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -40,6 +40,7 @@ from ospd.vtfilter import VtsFilter from ospd.misc import ScanCollection, ResultType, target_str_to_list from ospd.misc import resolve_hostname, valid_uuid +from ospd.misc import ScanStatus from ospd.xml import simple_response_str, get_result_xml from ospd.error import OSPDError @@ -603,8 +604,7 @@ def handle_start_scan_command(self, scan_et): scan_params = self.process_scan_params(params) scan_id = self.create_scan(scan_id, scan_targets, - target_str, scan_params, - vt_selection) + scan_params, vt_selection) scan_process = multiprocessing.Process(target=scan_func, args=(scan_id, scan_targets, @@ -635,7 +635,7 @@ def stop_scan(self, scan_id): if not scan_process.is_alive(): raise OSPDError('Scan already stopped or finished.', 'stop_scan') - self.set_scan_status(scan_id, "stopped") + self.set_scan_status(scan_id, ScanStatus.STOPPED) logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident) self.stop_scan_cleanup(scan_id) try: @@ -661,7 +661,7 @@ def exec_scan(self, scan_id, target): def finish_scan(self, scan_id): """ Sets a scan as finished. """ self.set_scan_progress(scan_id, 100) - self.set_scan_status(scan_id, 'finished') + self.set_scan_status(scan_id, ScanStatus.FINISHED) logger.info("%s: Scan finished.", scan_id) def get_daemon_name(self): @@ -837,8 +837,9 @@ def check_pending_target(self, scan_id, multiscan_proc): """ for running_target_proc, running_target_id in multiscan_proc: if not running_target_proc.is_alive(): - target_prog = self.get_scan_target_progress(scan_id) - if target_prog[running_target_id] < 100: + target_prog = self.get_scan_target_progress( + scan_id, running_target_id) + if target_prog < 100: self.stop_scan(scan_id) running_target = (running_target_proc, running_target_id) multiscan_proc.remove(running_target) @@ -848,8 +849,10 @@ def calculate_progress(self, scan_id): """ Calculate the total scan progress from the partial target progress. """ - target_progress = self.get_scan_target_progress(scan_id) - return sum(target_progress.values())/len(target_progress) + t_prog = dict() + for target in self.get_scan_target(scan_id): + t_prog[target] = self.get_scan_target_progress(scan_id, target) + return sum(t_prog.values())/len(t_prog) def start_scan(self, scan_id, targets, parallel=1): """ Handle N parallel scans if 'parallel' is greater than 1. """ @@ -863,14 +866,14 @@ def start_scan(self, scan_id, targets, parallel=1): for index, target in enumerate(target_list): while len(multiscan_proc) >= parallel: - multiscan_proc = self.check_pending_target(scan_id, - multiscan_proc) progress = self.calculate_progress(scan_id) self.set_scan_progress(scan_id, progress) + multiscan_proc = self.check_pending_target(scan_id, + multiscan_proc) time.sleep(1) #If the scan status is stopped, does not launch anymore target scans - if self.get_scan_status(scan_id) == "stopped": + if self.get_scan_status(scan_id) == ScanStatus.STOPPED: return logger.info("%s: Host scan started on ports %s.", target[0], target[1]) @@ -878,7 +881,7 @@ def start_scan(self, scan_id, targets, parallel=1): args=(scan_id, target[0])) multiscan_proc.append((scan_process, target[0])) scan_process.start() - self.set_scan_status(scan_id, "running") + self.set_scan_status(scan_id, ScanStatus.RUNNING) # Wait until all single target were scanned while multiscan_proc: @@ -889,14 +892,13 @@ def start_scan(self, scan_id, targets, parallel=1): time.sleep(1) # Only set the scan as finished if the scan was not stopped. - if self.get_scan_status(scan_id) != "stopped": + if self.get_scan_status(scan_id) != ScanStatus.STOPPED: self.finish_scan(scan_id) def dry_run_scan(self, scan_id, targets): """ Dry runs a scan. """ os.setsid() - #target_list = target_str_to_list(target_str) for _, target in enumerate(targets): host = resolve_hostname(target[0]) if host is None: @@ -919,12 +921,15 @@ def set_scan_host_finished(self, scan_id, target, host): self.scan_collection.set_host_finished(scan_id, target, host) def set_scan_progress(self, scan_id, progress): - """ Sets scan_id scan's progress which is a number between 0 and 100. """ + """ Sets scan_id scan's progress which is a number + between 0 and 100. """ self.scan_collection.set_progress(scan_id, progress) - def set_scan_target_progress(self, scan_id, target, progress): - """ Sets target's progress. """ - self.scan_collection.set_target_progress(scan_id, target, progress) + def set_scan_target_progress( + self, scan_id, target, host, progress): + """ Sets host's progress which is part of target. """ + self.scan_collection.set_target_progress( + scan_id, target, host, progress) def set_scan_status(self, scan_id, status): """ Set the scan's status.""" @@ -1072,6 +1077,9 @@ def delete_scan(self, scan_id): @return: 1 if scan deleted, 0 otherwise. """ + if self.get_scan_status(scan_id) == ScanStatus.RUNNING: + return 0 + try: del self.scan_processes[scan_id] except KeyError: @@ -1120,7 +1128,7 @@ def get_scan_xml(self, scan_id, detailed=True, pop_res=False): if not scan_id: return Element('scan') - target = self.get_scan_target(scan_id) + target = ','.join(self.get_scan_target(scan_id)) progress = self.get_scan_progress(scan_id) status = self.get_scan_status(scan_id) start_time = self.get_scan_start_time(scan_id) @@ -1129,7 +1137,7 @@ def get_scan_xml(self, scan_id, detailed=True, pop_res=False): for name, value in [('id', scan_id), ('target', target), ('progress', progress), - ('status', status), + ('status', status.name.lower()), ('start_time', start_time), ('end_time', end_time)]: response.set(name, str(value)) @@ -1555,7 +1563,7 @@ def scheduler(self): """ Should be implemented by subclass in case of need to run tasks periodically. """ - def create_scan(self, scan_id, targets, target_str, options, vts): + def create_scan(self, scan_id, targets, options, vts): """ Creates a new scan. @target: Target to scan. @@ -1563,7 +1571,7 @@ def create_scan(self, scan_id, targets, target_str, options, vts): @return: New scan's ID. """ - return self.scan_collection.create_scan(scan_id, targets, target_str, options, vts) + return self.scan_collection.create_scan(scan_id, targets, options, vts) def get_scan_options(self, scan_id): """ Gives a scan's list of options. """ @@ -1578,7 +1586,7 @@ def check_scan_process(self, scan_id): scan_process = self.scan_processes[scan_id] progress = self.get_scan_progress(scan_id) if progress < 100 and not scan_process.is_alive(): - self.set_scan_status(scan_id, 'stopped') + self.set_scan_status(scan_id, ScanStatus.STOPPED) self.add_scan_error(scan_id, name="", host="", value="Scan process failure.") logger.info("%s: Scan stopped with errors.", scan_id) @@ -1589,13 +1597,13 @@ def get_scan_progress(self, scan_id): """ Gives a scan's current progress value. """ return self.scan_collection.get_progress(scan_id) - def get_scan_target_progress(self, scan_id): + def get_scan_target_progress(self, scan_id, target): """ Gives a list with scan's current progress value of each target. """ - return self.scan_collection.get_target_progress(scan_id) + return self.scan_collection.get_target_progress(scan_id, target) def get_scan_target(self, scan_id): """ Gives a scan's target. """ - return self.scan_collection.get_target(scan_id) + return self.scan_collection.get_target_list(scan_id) def get_scan_ports(self, scan_id, target=''): """ Gives a scan's ports list. """ diff --git a/tests/testSSHDaemon.py b/tests/testSSHDaemon.py index 894bb7f8..77cedf23 100644 --- a/tests/testSSHDaemon.py +++ b/tests/testSSHDaemon.py @@ -69,7 +69,7 @@ def testNoParamiko(self): def testRunCommand(self): ospd_ssh.paramiko = fakeparamiko daemon = OSPDaemonSimpleSSH('cert', 'key', 'ca') - scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], '', + scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], dict(port=5, ssh_timeout=15, username_password='dummy:pw'), '') res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd') @@ -79,7 +79,7 @@ def testRunCommand(self): def testRunCommandLegacyCredential(self): ospd_ssh.paramiko = fakeparamiko daemon = OSPDaemonSimpleSSH('cert', 'key', 'ca') - scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], '', + scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], dict(port=5, ssh_timeout=15, username='dummy', password='pw'), '') res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd') @@ -99,7 +99,7 @@ def testRunCommandNewCredential(self): 'username': 'smbuser'}} scanid = daemon.create_scan(None, - [['host.example.com', '80, 443', cred_dict],], '', + [['host.example.com', '80, 443', cred_dict],], dict(port=5, ssh_timeout=15), '') res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd') self.assertTrue(isinstance(res, list)) @@ -108,7 +108,7 @@ def testRunCommandNewCredential(self): def testRunCommandNoCredential(self): ospd_ssh.paramiko = fakeparamiko daemon = OSPDaemonSimpleSSH('cert', 'key', 'ca') - scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], '', + scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], dict(port=5, ssh_timeout=15), '') self.assertRaises(ValueError, daemon.run_command, scanid, 'host.example.com', 'cat /etc/passwd' ) diff --git a/tests/testScanAndResult.py b/tests/testScanAndResult.py index 26ed9bc7..8b0214d6 100644 --- a/tests/testScanAndResult.py +++ b/tests/testScanAndResult.py @@ -458,7 +458,7 @@ def testScanWithError(self): self.assertEqual(1, len(scans)) scan = scans[0] status = scan.get('status') - if not status or status == 'running': + if status == "init" or status == "running": self.assertEqual('0', scan.get('end_time')) time.sleep(.010) else: @@ -696,21 +696,6 @@ def testScanGetTarget(self): scan_res = response.find('scan') self.assertEqual(scan_res.get('target'), 'localhosts,192.168.0.0/24') - def testScanGetLegacyTarget(self): - daemon = DummyWrapper([]) - - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - '')) - scan_id = response.findtext('id') - response = secET.fromstring( - daemon.handle_command('' % scan_id)) - scan_res = response.find('scan') - self.assertEqual(scan_res.get('target'), 'localhosts,192.168.0.0/24') - def testScanMultiTargetParallelWithError(self): daemon = DummyWrapper([]) cmd = secET.fromstring('' @@ -750,8 +735,8 @@ def testProgress(self): '' '')) scan_id = response.findtext('id') - daemon.set_scan_target_progress(scan_id, 'localhost1', 75) - daemon.set_scan_target_progress(scan_id, 'localhost2', 25) + daemon.set_scan_target_progress(scan_id, 'localhost1', 'localhost1', 75) + daemon.set_scan_target_progress(scan_id, 'localhost2', 'localhost2', 25) self.assertEqual(daemon.calculate_progress(scan_id), 50) def testSetGetVtsVersion(self):