Skip to content
This repository has been archived by the owner on Nov 29, 2021. It is now read-only.

Fix progress and improvements. #101

Merged
merged 6 commits into from
Apr 15, 2019
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions ospd/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import uuid
import multiprocessing
import itertools
from enum import Enum

LOGGER = logging.getLogger(__name__)

Expand All @@ -51,6 +52,14 @@
PORT = 1234
ADDRESS = "0.0.0.0"

class ScanStatus(Enum):
"""Scan status. """
INIT = 0
RUNNING = 1
STOPPED = 2
FINISHED = 3



class ScanCollection(object):

Expand Down Expand Up @@ -103,15 +112,14 @@ def set_progress(self, scan_id, progress):
if progress == 100:
self.scans_table[scan_id]['end_time'] = int(time.time())

def set_target_progress(self, scan_id, target, progress):
def set_target_progress(self, scan_id, target, host, progress):
""" Sets scan_id scan's progress. """
if progress > 0 and progress <= 100:
target_process = dict()
target_process = self.scans_table[scan_id]['target_progress']
target_process[target] = progress
targets = self.scans_table[scan_id]['target_progress']
targets[target][host] = progress
# Set scan_info's target_progress to propagate progresses
# to parent process.
self.scans_table[scan_id]['target_progress'] = target_process
self.scans_table[scan_id]['target_progress'] = targets

def set_host_finished(self, scan_id, target, host):
""" Add the host in a list of finished hosts """
Expand All @@ -135,26 +143,26 @@ def ids_iterator(self):

return iter(self.scans_table.keys())

def create_scan(self, scan_id='', targets='', target_str=None,
options=dict(), vts=''):
def create_scan(self, scan_id='', targets='', options=None, vts=''):
""" Creates a new scan with provided scan information. """

if self.data_manager is None:
self.data_manager = multiprocessing.Manager()
if not options:
options = Dict()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
options = Dict()
options = dict()

scan_info = self.data_manager.dict()
scan_info['results'] = list()
scan_info['finished_hosts'] = dict(
[[target, []] for target, _, _ in targets])
scan_info['progress'] = 0
scan_info['target_progress'] = dict(
[[target, 0] for target, _, _ in targets])
[[target, {}] for target, _, _ in targets])
scan_info['targets'] = targets
scan_info['legacy_target'] = target_str
scan_info['vts'] = vts
scan_info['options'] = options
scan_info['start_time'] = int(time.time())
scan_info['end_time'] = "0"
scan_info['status'] = ""
scan_info['status'] = ScanStatus.INIT
if scan_id is None or scan_id == '':
scan_id = str(uuid.uuid4())
scan_info['scan_id'] = scan_id
Expand Down Expand Up @@ -185,10 +193,15 @@ def get_progress(self, scan_id):

return self.scans_table[scan_id]['progress']

def get_target_progress(self, scan_id):
""" Get a scan's current progress value. """
def get_target_progress(self, scan_id, target):
""" Get a target's current progress value.
The value is calculated with the progress of each single host
in the target."""

return self.scans_table[scan_id]['target_progress']
total_hosts = len(target_str_to_list(target))
host_progresses = self.scans_table[scan_id]['target_progress'].get(target)
t_prog = sum(host_progresses.values()) / total_hosts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you check and ensure that total_hosts will never be zero?

return t_prog

def get_start_time(self, scan_id):
""" Get a scan's start time. """
Expand All @@ -202,14 +215,11 @@ def get_end_time(self, scan_id):

def get_target(self, scan_id):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I would rename the method to get_targets or get_target_list.

""" Get a scan's target list. """
if self.scans_table[scan_id]['legacy_target']:
return self.scans_table[scan_id]['legacy_target']

target_list = []
for item in self.scans_table[scan_id]['targets']:
target_list.append(item[0])
separ = ','
return separ.join(target_list)
for target, _, _ in self.scans_table[scan_id]['targets']:
target_list.append(target)
return target_list

def get_ports(self, scan_id, target):
""" Get a scan's ports list. If a target is specified
Expand Down Expand Up @@ -246,7 +256,7 @@ def id_exists(self, scan_id):
def delete_scan(self, scan_id):
""" Delete a scan if fully finished. """

if self.get_status(scan_id) == "running":
if self.get_status(scan_id) == ScanStatus.RUNNING:
return False
self.scans_table.pop(scan_id)
if len(self.scans_table) == 0:
Expand Down
58 changes: 33 additions & 25 deletions ospd/ospd.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ospd.vtfilter import VtsFilter
from ospd.misc import ScanCollection, ResultType, target_str_to_list
from ospd.misc import resolve_hostname, valid_uuid
from ospd.misc import ScanStatus
from ospd.xml import simple_response_str, get_result_xml
from ospd.error import OSPDError

Expand Down Expand Up @@ -603,8 +604,7 @@ def handle_start_scan_command(self, scan_et):
scan_params = self.process_scan_params(params)

scan_id = self.create_scan(scan_id, scan_targets,
target_str, scan_params,
vt_selection)
scan_params, vt_selection)
scan_process = multiprocessing.Process(target=scan_func,
args=(scan_id,
scan_targets,
Expand Down Expand Up @@ -635,7 +635,7 @@ def stop_scan(self, scan_id):
if not scan_process.is_alive():
raise OSPDError('Scan already stopped or finished.', 'stop_scan')

self.set_scan_status(scan_id, "stopped")
self.set_scan_status(scan_id, ScanStatus.STOPPED)
logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident)
self.stop_scan_cleanup(scan_id)
try:
Expand All @@ -661,7 +661,7 @@ def exec_scan(self, scan_id, target):
def finish_scan(self, scan_id):
""" Sets a scan as finished. """
self.set_scan_progress(scan_id, 100)
self.set_scan_status(scan_id, 'finished')
self.set_scan_status(scan_id, ScanStatus.FINISHED)
logger.info("%s: Scan finished.", scan_id)

def get_daemon_name(self):
Expand Down Expand Up @@ -837,8 +837,9 @@ def check_pending_target(self, scan_id, multiscan_proc):
"""
for running_target_proc, running_target_id in multiscan_proc:
if not running_target_proc.is_alive():
target_prog = self.get_scan_target_progress(scan_id)
if target_prog[running_target_id] < 100:
target_prog = self.get_scan_target_progress(
scan_id, running_target_id)
if target_prog < 100:
self.stop_scan(scan_id)
running_target = (running_target_proc, running_target_id)
multiscan_proc.remove(running_target)
Expand All @@ -848,8 +849,10 @@ def calculate_progress(self, scan_id):
""" Calculate the total scan progress from the
partial target progress. """

target_progress = self.get_scan_target_progress(scan_id)
return sum(target_progress.values())/len(target_progress)
t_prog = dict()
for target in self.get_scan_target(scan_id):
t_prog[target] = self.get_scan_target_progress(scan_id, target)
return sum(t_prog.values())/len(t_prog)

def start_scan(self, scan_id, targets, parallel=1):
""" Handle N parallel scans if 'parallel' is greater than 1. """
Expand All @@ -863,22 +866,22 @@ def start_scan(self, scan_id, targets, parallel=1):

for index, target in enumerate(target_list):
while len(multiscan_proc) >= parallel:
multiscan_proc = self.check_pending_target(scan_id,
multiscan_proc)
progress = self.calculate_progress(scan_id)
self.set_scan_progress(scan_id, progress)
multiscan_proc = self.check_pending_target(scan_id,
multiscan_proc)
time.sleep(1)

#If the scan status is stopped, does not launch anymore target scans
if self.get_scan_status(scan_id) == "stopped":
if self.get_scan_status(scan_id) == ScanStatus.STOPPED:
return

logger.info("%s: Host scan started on ports %s.", target[0], target[1])
scan_process = multiprocessing.Process(target=self.parallel_scan,
args=(scan_id, target[0]))
multiscan_proc.append((scan_process, target[0]))
scan_process.start()
self.set_scan_status(scan_id, "running")
self.set_scan_status(scan_id, ScanStatus.RUNNING)

# Wait until all single target were scanned
while multiscan_proc:
Expand All @@ -889,14 +892,13 @@ def start_scan(self, scan_id, targets, parallel=1):
time.sleep(1)

# Only set the scan as finished if the scan was not stopped.
if self.get_scan_status(scan_id) != "stopped":
if self.get_scan_status(scan_id) != ScanStatus.STOPPED:
self.finish_scan(scan_id)

def dry_run_scan(self, scan_id, targets):
""" Dry runs a scan. """

os.setsid()
#target_list = target_str_to_list(target_str)
for _, target in enumerate(targets):
host = resolve_hostname(target[0])
if host is None:
Expand All @@ -919,12 +921,15 @@ def set_scan_host_finished(self, scan_id, target, host):
self.scan_collection.set_host_finished(scan_id, target, host)

def set_scan_progress(self, scan_id, progress):
""" Sets scan_id scan's progress which is a number between 0 and 100. """
""" Sets scan_id scan's progress which is a number
between 0 and 100. """
self.scan_collection.set_progress(scan_id, progress)

def set_scan_target_progress(self, scan_id, target, progress):
""" Sets target's progress. """
self.scan_collection.set_target_progress(scan_id, target, progress)
def set_scan_target_progress(
self, scan_id, target, host, progress):
""" Sets host's progress which is part of target. """
self.scan_collection.set_target_progress(
scan_id, target, host, progress)

def set_scan_status(self, scan_id, status):
""" Set the scan's status."""
Expand Down Expand Up @@ -1072,6 +1077,9 @@ def delete_scan(self, scan_id):

@return: 1 if scan deleted, 0 otherwise.
"""
if self.get_scan_status(scan_id) == ScanStatus.RUNNING:
return 0

try:
del self.scan_processes[scan_id]
except KeyError:
Expand Down Expand Up @@ -1120,7 +1128,7 @@ def get_scan_xml(self, scan_id, detailed=True, pop_res=False):
if not scan_id:
return Element('scan')

target = self.get_scan_target(scan_id)
target = ','.join(self.get_scan_target(scan_id))
progress = self.get_scan_progress(scan_id)
status = self.get_scan_status(scan_id)
start_time = self.get_scan_start_time(scan_id)
Expand All @@ -1129,7 +1137,7 @@ def get_scan_xml(self, scan_id, detailed=True, pop_res=False):
for name, value in [('id', scan_id),
('target', target),
('progress', progress),
('status', status),
('status', status.name.lower()),
('start_time', start_time),
('end_time', end_time)]:
response.set(name, str(value))
Expand Down Expand Up @@ -1555,15 +1563,15 @@ def scheduler(self):
""" Should be implemented by subclass in case of need
to run tasks periodically. """

def create_scan(self, scan_id, targets, target_str, options, vts):
def create_scan(self, scan_id, targets, options, vts):
""" Creates a new scan.

@target: Target to scan.
@options: Miscellaneous scan options.

@return: New scan's ID.
"""
return self.scan_collection.create_scan(scan_id, targets, target_str, options, vts)
return self.scan_collection.create_scan(scan_id, targets, options, vts)

def get_scan_options(self, scan_id):
""" Gives a scan's list of options. """
Expand All @@ -1578,7 +1586,7 @@ def check_scan_process(self, scan_id):
scan_process = self.scan_processes[scan_id]
progress = self.get_scan_progress(scan_id)
if progress < 100 and not scan_process.is_alive():
self.set_scan_status(scan_id, 'stopped')
self.set_scan_status(scan_id, ScanStatus.STOPPED)
self.add_scan_error(scan_id, name="", host="",
value="Scan process failure.")
logger.info("%s: Scan stopped with errors.", scan_id)
Expand All @@ -1589,9 +1597,9 @@ def get_scan_progress(self, scan_id):
""" Gives a scan's current progress value. """
return self.scan_collection.get_progress(scan_id)

def get_scan_target_progress(self, scan_id):
def get_scan_target_progress(self, scan_id, target):
""" Gives a list with scan's current progress value of each target. """
return self.scan_collection.get_target_progress(scan_id)
return self.scan_collection.get_target_progress(scan_id, target)

def get_scan_target(self, scan_id):
""" Gives a scan's target. """
Expand Down
8 changes: 4 additions & 4 deletions tests/testSSHDaemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def testNoParamiko(self):
def testRunCommand(self):
ospd_ssh.paramiko = fakeparamiko
daemon = OSPDaemonSimpleSSH('cert', 'key', 'ca')
scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], '',
scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],],
dict(port=5, ssh_timeout=15,
username_password='dummy:pw'), '')
res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd')
Expand All @@ -79,7 +79,7 @@ def testRunCommand(self):
def testRunCommandLegacyCredential(self):
ospd_ssh.paramiko = fakeparamiko
daemon = OSPDaemonSimpleSSH('cert', 'key', 'ca')
scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], '',
scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],],
dict(port=5, ssh_timeout=15,
username='dummy', password='pw'), '')
res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd')
Expand All @@ -99,7 +99,7 @@ def testRunCommandNewCredential(self):
'username': 'smbuser'}}

scanid = daemon.create_scan(None,
[['host.example.com', '80, 443', cred_dict],], '',
[['host.example.com', '80, 443', cred_dict],],
dict(port=5, ssh_timeout=15), '')
res = daemon.run_command(scanid, 'host.example.com', 'cat /etc/passwd')
self.assertTrue(isinstance(res, list))
Expand All @@ -108,7 +108,7 @@ def testRunCommandNewCredential(self):
def testRunCommandNoCredential(self):
ospd_ssh.paramiko = fakeparamiko
daemon = OSPDaemonSimpleSSH('cert', 'key', 'ca')
scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],], '',
scanid = daemon.create_scan(None, [['host.example.com', '80, 443', ''],],
dict(port=5, ssh_timeout=15), '')
self.assertRaises(ValueError, daemon.run_command,
scanid, 'host.example.com', 'cat /etc/passwd' )
21 changes: 3 additions & 18 deletions tests/testScanAndResult.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def testScanWithError(self):
self.assertEqual(1, len(scans))
scan = scans[0]
status = scan.get('status')
if not status or status == 'running':
if status == "init" or status == "running":
self.assertEqual('0', scan.get('end_time'))
time.sleep(.010)
else:
Expand Down Expand Up @@ -696,21 +696,6 @@ def testScanGetTarget(self):
scan_res = response.find('scan')
self.assertEqual(scan_res.get('target'), 'localhosts,192.168.0.0/24')

def testScanGetLegacyTarget(self):
daemon = DummyWrapper([])

response = secET.fromstring(
daemon.handle_command(
'<start_scan target="localhosts,192.168.0.0/24" ports="22">'
'<scanner_params /><vts><vt id="1.2.3.4" />'
'</vts>'
'</start_scan>'))
scan_id = response.findtext('id')
response = secET.fromstring(
daemon.handle_command('<get_scans scan_id="%s"/>' % scan_id))
scan_res = response.find('scan')
self.assertEqual(scan_res.get('target'), 'localhosts,192.168.0.0/24')

def testScanMultiTargetParallelWithError(self):
daemon = DummyWrapper([])
cmd = secET.fromstring('<start_scan parallel="100a">'
Expand Down Expand Up @@ -750,8 +735,8 @@ def testProgress(self):
'</target></targets>'
'</start_scan>'))
scan_id = response.findtext('id')
daemon.set_scan_target_progress(scan_id, 'localhost1', 75)
daemon.set_scan_target_progress(scan_id, 'localhost2', 25)
daemon.set_scan_target_progress(scan_id, 'localhost1', 'localhost1', 75)
daemon.set_scan_target_progress(scan_id, 'localhost2', 'localhost2', 25)
self.assertEqual(daemon.calculate_progress(scan_id), 50)

def testSetGetVtsVersion(self):
Expand Down