diff --git a/ospd/misc.py b/ospd/misc.py
index 9378e82f..a746d9f9 100644
--- a/ospd/misc.py
+++ b/ospd/misc.py
@@ -40,6 +40,7 @@
import uuid
import multiprocessing
import itertools
+from enum import Enum
LOGGER = logging.getLogger(__name__)
@@ -51,6 +52,14 @@
PORT = 1234
ADDRESS = "0.0.0.0"
+class ScanStatus(Enum):
+ """Scan status. """
+ INIT = 0
+ RUNNING = 1
+ STOPPED = 2
+ FINISHED = 3
+
+
class ScanCollection(object):
@@ -103,15 +112,14 @@ def set_progress(self, scan_id, progress):
if progress == 100:
self.scans_table[scan_id]['end_time'] = int(time.time())
- def set_target_progress(self, scan_id, target, progress):
+ def set_target_progress(self, scan_id, target, host, progress):
""" Sets scan_id scan's progress. """
if progress > 0 and progress <= 100:
- target_process = dict()
- target_process = self.scans_table[scan_id]['target_progress']
- target_process[target] = progress
+ targets = self.scans_table[scan_id]['target_progress']
+ targets[target][host] = progress
# Set scan_info's target_progress to propagate progresses
# to parent process.
- self.scans_table[scan_id]['target_progress'] = target_process
+ self.scans_table[scan_id]['target_progress'] = targets
def set_host_finished(self, scan_id, target, host):
""" Add the host in a list of finished hosts """
@@ -135,26 +143,26 @@ def ids_iterator(self):
return iter(self.scans_table.keys())
- def create_scan(self, scan_id='', targets='', target_str=None,
- options=dict(), vts=''):
+ def create_scan(self, scan_id='', targets='', options=None, vts=''):
""" Creates a new scan with provided scan information. """
if self.data_manager is None:
self.data_manager = multiprocessing.Manager()
+ if not options:
+ options = dict()
scan_info = self.data_manager.dict()
scan_info['results'] = list()
scan_info['finished_hosts'] = dict(
[[target, []] for target, _, _ in targets])
scan_info['progress'] = 0
scan_info['target_progress'] = dict(
- [[target, 0] for target, _, _ in targets])
+ [[target, {}] for target, _, _ in targets])
scan_info['targets'] = targets
- scan_info['legacy_target'] = target_str
scan_info['vts'] = vts
scan_info['options'] = options
scan_info['start_time'] = int(time.time())
scan_info['end_time'] = "0"
- scan_info['status'] = ""
+ scan_info['status'] = ScanStatus.INIT
if scan_id is None or scan_id == '':
scan_id = str(uuid.uuid4())
scan_info['scan_id'] = scan_id
@@ -185,10 +193,19 @@ def get_progress(self, scan_id):
return self.scans_table[scan_id]['progress']
- def get_target_progress(self, scan_id):
- """ Get a scan's current progress value. """
+ def get_target_progress(self, scan_id, target):
+ """ Get a target's current progress value.
+ The value is calculated with the progress of each single host
+ in the target."""
- return self.scans_table[scan_id]['target_progress']
+ total_hosts = len(target_str_to_list(target))
+ host_progresses = self.scans_table[scan_id]['target_progress'].get(target)
+ try:
+ t_prog = sum(host_progresses.values()) / total_hosts
+ except ZeroDivisionError:
+ LOGGER.error("Zero division error in ", get_target_progress.__name__)
+ raise
+ return t_prog
def get_start_time(self, scan_id):
""" Get a scan's start time. """
@@ -200,16 +217,13 @@ def get_end_time(self, scan_id):
return self.scans_table[scan_id]['end_time']
- def get_target(self, scan_id):
+ def get_target_list(self, scan_id):
""" Get a scan's target list. """
- if self.scans_table[scan_id]['legacy_target']:
- return self.scans_table[scan_id]['legacy_target']
target_list = []
- for item in self.scans_table[scan_id]['targets']:
- target_list.append(item[0])
- separ = ','
- return separ.join(target_list)
+ for target, _, _ in self.scans_table[scan_id]['targets']:
+ target_list.append(target)
+ return target_list
def get_ports(self, scan_id, target):
""" Get a scan's ports list. If a target is specified
@@ -246,7 +260,7 @@ def id_exists(self, scan_id):
def delete_scan(self, scan_id):
""" Delete a scan if fully finished. """
- if self.get_status(scan_id) == "running":
+ if self.get_status(scan_id) == ScanStatus.RUNNING:
return False
self.scans_table.pop(scan_id)
if len(self.scans_table) == 0:
diff --git a/ospd/ospd.py b/ospd/ospd.py
index 0b5d459e..b904e1c4 100644
--- a/ospd/ospd.py
+++ b/ospd/ospd.py
@@ -40,6 +40,7 @@
from ospd.vtfilter import VtsFilter
from ospd.misc import ScanCollection, ResultType, target_str_to_list
from ospd.misc import resolve_hostname, valid_uuid
+from ospd.misc import ScanStatus
from ospd.xml import simple_response_str, get_result_xml
from ospd.error import OSPDError
@@ -603,8 +604,7 @@ def handle_start_scan_command(self, scan_et):
scan_params = self.process_scan_params(params)
scan_id = self.create_scan(scan_id, scan_targets,
- target_str, scan_params,
- vt_selection)
+ scan_params, vt_selection)
scan_process = multiprocessing.Process(target=scan_func,
args=(scan_id,
scan_targets,
@@ -635,7 +635,7 @@ def stop_scan(self, scan_id):
if not scan_process.is_alive():
raise OSPDError('Scan already stopped or finished.', 'stop_scan')
- self.set_scan_status(scan_id, "stopped")
+ self.set_scan_status(scan_id, ScanStatus.STOPPED)
logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident)
self.stop_scan_cleanup(scan_id)
try:
@@ -661,7 +661,7 @@ def exec_scan(self, scan_id, target):
def finish_scan(self, scan_id):
""" Sets a scan as finished. """
self.set_scan_progress(scan_id, 100)
- self.set_scan_status(scan_id, 'finished')
+ self.set_scan_status(scan_id, ScanStatus.FINISHED)
logger.info("%s: Scan finished.", scan_id)
def get_daemon_name(self):
@@ -837,8 +837,9 @@ def check_pending_target(self, scan_id, multiscan_proc):
"""
for running_target_proc, running_target_id in multiscan_proc:
if not running_target_proc.is_alive():
- target_prog = self.get_scan_target_progress(scan_id)
- if target_prog[running_target_id] < 100:
+ target_prog = self.get_scan_target_progress(
+ scan_id, running_target_id)
+ if target_prog < 100:
self.stop_scan(scan_id)
running_target = (running_target_proc, running_target_id)
multiscan_proc.remove(running_target)
@@ -848,8 +849,10 @@ def calculate_progress(self, scan_id):
""" Calculate the total scan progress from the
partial target progress. """
- target_progress = self.get_scan_target_progress(scan_id)
- return sum(target_progress.values())/len(target_progress)
+ t_prog = dict()
+ for target in self.get_scan_target(scan_id):
+ t_prog[target] = self.get_scan_target_progress(scan_id, target)
+ return sum(t_prog.values())/len(t_prog)
def start_scan(self, scan_id, targets, parallel=1):
""" Handle N parallel scans if 'parallel' is greater than 1. """
@@ -863,14 +866,14 @@ def start_scan(self, scan_id, targets, parallel=1):
for index, target in enumerate(target_list):
while len(multiscan_proc) >= parallel:
- multiscan_proc = self.check_pending_target(scan_id,
- multiscan_proc)
progress = self.calculate_progress(scan_id)
self.set_scan_progress(scan_id, progress)
+ multiscan_proc = self.check_pending_target(scan_id,
+ multiscan_proc)
time.sleep(1)
#If the scan status is stopped, does not launch anymore target scans
- if self.get_scan_status(scan_id) == "stopped":
+ if self.get_scan_status(scan_id) == ScanStatus.STOPPED:
return
logger.info("%s: Host scan started on ports %s.", target[0], target[1])
@@ -878,7 +881,7 @@ def start_scan(self, scan_id, targets, parallel=1):
args=(scan_id, target[0]))
multiscan_proc.append((scan_process, target[0]))
scan_process.start()
- self.set_scan_status(scan_id, "running")
+ self.set_scan_status(scan_id, ScanStatus.RUNNING)
# Wait until all single target were scanned
while multiscan_proc:
@@ -889,14 +892,13 @@ def start_scan(self, scan_id, targets, parallel=1):
time.sleep(1)
# Only set the scan as finished if the scan was not stopped.
- if self.get_scan_status(scan_id) != "stopped":
+ if self.get_scan_status(scan_id) != ScanStatus.STOPPED:
self.finish_scan(scan_id)
def dry_run_scan(self, scan_id, targets):
""" Dry runs a scan. """
os.setsid()
- #target_list = target_str_to_list(target_str)
for _, target in enumerate(targets):
host = resolve_hostname(target[0])
if host is None:
@@ -919,12 +921,15 @@ def set_scan_host_finished(self, scan_id, target, host):
self.scan_collection.set_host_finished(scan_id, target, host)
def set_scan_progress(self, scan_id, progress):
- """ Sets scan_id scan's progress which is a number between 0 and 100. """
+ """ Sets scan_id scan's progress which is a number
+ between 0 and 100. """
self.scan_collection.set_progress(scan_id, progress)
- def set_scan_target_progress(self, scan_id, target, progress):
- """ Sets target's progress. """
- self.scan_collection.set_target_progress(scan_id, target, progress)
+ def set_scan_target_progress(
+ self, scan_id, target, host, progress):
+ """ Sets host's progress which is part of target. """
+ self.scan_collection.set_target_progress(
+ scan_id, target, host, progress)
def set_scan_status(self, scan_id, status):
""" Set the scan's status."""
@@ -1072,6 +1077,9 @@ def delete_scan(self, scan_id):
@return: 1 if scan deleted, 0 otherwise.
"""
+ if self.get_scan_status(scan_id) == ScanStatus.RUNNING:
+ return 0
+
try:
del self.scan_processes[scan_id]
except KeyError:
@@ -1120,7 +1128,7 @@ def get_scan_xml(self, scan_id, detailed=True, pop_res=False):
if not scan_id:
return Element('scan')
- target = self.get_scan_target(scan_id)
+ target = ','.join(self.get_scan_target(scan_id))
progress = self.get_scan_progress(scan_id)
status = self.get_scan_status(scan_id)
start_time = self.get_scan_start_time(scan_id)
@@ -1129,7 +1137,7 @@ def get_scan_xml(self, scan_id, detailed=True, pop_res=False):
for name, value in [('id', scan_id),
('target', target),
('progress', progress),
- ('status', status),
+ ('status', status.name.lower()),
('start_time', start_time),
('end_time', end_time)]:
response.set(name, str(value))
@@ -1555,7 +1563,7 @@ def scheduler(self):
""" Should be implemented by subclass in case of need
to run tasks periodically. """
- def create_scan(self, scan_id, targets, target_str, options, vts):
+ def create_scan(self, scan_id, targets, options, vts):
""" Creates a new scan.
@target: Target to scan.
@@ -1563,7 +1571,7 @@ def create_scan(self, scan_id, targets, target_str, options, vts):
@return: New scan's ID.
"""
- return self.scan_collection.create_scan(scan_id, targets, target_str, options, vts)
+ return self.scan_collection.create_scan(scan_id, targets, options, vts)
def get_scan_options(self, scan_id):
""" Gives a scan's list of options. """
@@ -1578,7 +1586,7 @@ def check_scan_process(self, scan_id):
scan_process = self.scan_processes[scan_id]
progress = self.get_scan_progress(scan_id)
if progress < 100 and not scan_process.is_alive():
- self.set_scan_status(scan_id, 'stopped')
+ self.set_scan_status(scan_id, ScanStatus.STOPPED)
self.add_scan_error(scan_id, name="", host="",
value="Scan process failure.")
logger.info("%s: Scan stopped with errors.", scan_id)
@@ -1589,13 +1597,13 @@ def get_scan_progress(self, scan_id):
""" Gives a scan's current progress value. """
return self.scan_collection.get_progress(scan_id)
- def get_scan_target_progress(self, scan_id):
+ def get_scan_target_progress(self, scan_id, target):
""" Gives a list with scan's current progress value of each target. """
- return self.scan_collection.get_target_progress(scan_id)
+ return self.scan_collection.get_target_progress(scan_id, target)
def get_scan_target(self, scan_id):
""" Gives a scan's target. """
- return self.scan_collection.get_target(scan_id)
+ return self.scan_collection.get_target_list(scan_id)
def get_scan_ports(self, scan_id, target=''):
""" Gives a scan's ports list. """
diff --git a/tests/testSSHDaemon.py b/tests/testSSHDaemon.py
index 894bb7f8..77cedf23 100644
--- a/tests/testSSHDaemon.py
+++ b/tests/testSSHDaemon.py
@@ -69,7 +69,7 @@ def testNoParamiko(self):
def testRunCommand(self):
ospd_ssh.paramiko = fakeparamiko
daemon = OSPDaemonSimpleSSH('cert', 'key', 'ca')
- scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], '',
+ scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],],
dict(port=5, ssh_timeout=15,
username_password='dummy:pw'), '')
res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd')
@@ -79,7 +79,7 @@ def testRunCommand(self):
def testRunCommandLegacyCredential(self):
ospd_ssh.paramiko = fakeparamiko
daemon = OSPDaemonSimpleSSH('cert', 'key', 'ca')
- scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], '',
+ scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],],
dict(port=5, ssh_timeout=15,
username='dummy', password='pw'), '')
res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd')
@@ -99,7 +99,7 @@ def testRunCommandNewCredential(self):
'username': 'smbuser'}}
scanid = daemon.create_scan(None,
- [['host.example.com', '80, 443', cred_dict],], '',
+ [['host.example.com', '80, 443', cred_dict],],
dict(port=5, ssh_timeout=15), '')
res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd')
self.assertTrue(isinstance(res, list))
@@ -108,7 +108,7 @@ def testRunCommandNewCredential(self):
def testRunCommandNoCredential(self):
ospd_ssh.paramiko = fakeparamiko
daemon = OSPDaemonSimpleSSH('cert', 'key', 'ca')
- scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], '',
+ scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],],
dict(port=5, ssh_timeout=15), '')
self.assertRaises(ValueError, daemon.run_command,
scanid, 'host.example.com', 'cat /etc/passwd' )
diff --git a/tests/testScanAndResult.py b/tests/testScanAndResult.py
index 26ed9bc7..8b0214d6 100644
--- a/tests/testScanAndResult.py
+++ b/tests/testScanAndResult.py
@@ -458,7 +458,7 @@ def testScanWithError(self):
self.assertEqual(1, len(scans))
scan = scans[0]
status = scan.get('status')
- if not status or status == 'running':
+ if status == "init" or status == "running":
self.assertEqual('0', scan.get('end_time'))
time.sleep(.010)
else:
@@ -696,21 +696,6 @@ def testScanGetTarget(self):
scan_res = response.find('scan')
self.assertEqual(scan_res.get('target'), 'localhosts,192.168.0.0/24')
- def testScanGetLegacyTarget(self):
- daemon = DummyWrapper([])
-
- response = secET.fromstring(
- daemon.handle_command(
- ''
- ''
- ''
- ''))
- scan_id = response.findtext('id')
- response = secET.fromstring(
- daemon.handle_command('' % scan_id))
- scan_res = response.find('scan')
- self.assertEqual(scan_res.get('target'), 'localhosts,192.168.0.0/24')
-
def testScanMultiTargetParallelWithError(self):
daemon = DummyWrapper([])
cmd = secET.fromstring(''
@@ -750,8 +735,8 @@ def testProgress(self):
''
''))
scan_id = response.findtext('id')
- daemon.set_scan_target_progress(scan_id, 'localhost1', 75)
- daemon.set_scan_target_progress(scan_id, 'localhost2', 25)
+ daemon.set_scan_target_progress(scan_id, 'localhost1', 'localhost1', 75)
+ daemon.set_scan_target_progress(scan_id, 'localhost2', 'localhost2', 25)
self.assertEqual(daemon.calculate_progress(scan_id), 50)
def testSetGetVtsVersion(self):