diff --git a/gvm/errors.py b/gvm/errors.py index d7c80da23..57266b500 100644 --- a/gvm/errors.py +++ b/gvm/errors.py @@ -21,6 +21,15 @@ class GvmError(Exception): - """A exception for gvm errors + """An exception for gvm errors """ pass + + +class InvalidArgument(GvmError): + """Raised if an invalid argument/parameter is passed + """ + +class RequiredArgument(GvmError): + """Raised if a required argument/parameter is missing + """ diff --git a/gvm/protocols/base.py b/gvm/protocols/base.py index 555426905..135b264f4 100644 --- a/gvm/protocols/base.py +++ b/gvm/protocols/base.py @@ -80,6 +80,14 @@ def _transform(self, data): return data return transform(data) + def _send_xml_command(self, xmlcmd): + """Send a xml command to the remote server + + Arguments: + xmlcmd (gvm.xml.XmlCommand): XmlCommand instance to send + """ + return self.send_command(xmlcmd.to_string()) + def is_connected(self): """Status of the current connection diff --git a/gvm/protocols/gmpv7.py b/gvm/protocols/gmpv7.py index a9368e349..a9564a4fd 100644 --- a/gvm/protocols/gmpv7.py +++ b/gvm/protocols/gmpv7.py @@ -22,8 +22,9 @@ from lxml import etree +from gvm.errors import InvalidArgument, RequiredArgument from gvm.utils import get_version_string -from gvm.xml import _GmpCommandFactory as GmpCommandFactory +from gvm.xml import _GmpCommandFactory as GmpCommandFactory, XmlCommand from .base import GvmProtocol @@ -31,6 +32,30 @@ PROTOCOL_VERSION = (7,) +FILTER_TYPES = [ + 'agent', + 'alert', + 'asset', + 'config', + 'credential', + 'filter', + 'group', + 'note', + 'override', + 'permission', + 'port_list', + 'report', + 'report_format', + 'result', + 'role', + 'schedule', + 'secinfo', + 'tag', + 'target', + 'task', + 'user', +] + def _check_command_status(xml): """Check gmp response @@ -163,10 +188,45 @@ def create_credential(self, name, **kwargs): cmd = self._generator.create_credential_command(name, kwargs) return self.send_command(cmd) - def create_filter(self, name, make_unique, **kwargs): - cmd = self._generator.create_filter_command(name, make_unique, - kwargs) - return self.send_command(cmd) + def create_filter(self, name, make_unique=False, filter_type=None, + comment=None, term=None, copy=None): + """Create a new filter + + Arguments: + name (str): Name of the new filter + make_unique (Boolean): + filter_type (str): Filter for entity type + comment (str): Comment for the filter + term (str): Filter term e.g. 'name=foo' + copy (str): UUID of an existing filter + """ + if not name: + raise RequiredArgument('create_filter requires a name argument') + + cmd = XmlCommand('create_filter') + _xmlname = cmd.add_element('name', name) + if make_unique: + _xmlname.add_element('make_unique', '1') + + if comment: + cmd.add_element('comment', comment) + + # TODO: Move copy into an extra method + if copy: + cmd.add_element('copy', copy) + + if term: + cmd.add_element('term', term) + + if filter_type: + filter_type = filter_type.lower() + if filter_type not in FILTER_TYPES: + raise InvalidArgument( + 'create_filter requires type to be one of {0} but ' + 'was {1}'.format(', '.join(FILTER_TYPES), filter_type)) + cmd.add_element('type', filter_type) + + return self._send_xml_command(cmd) def create_group(self, name, **kwargs): cmd = self._generator.create_group_command(name, kwargs) diff --git a/gvm/xml.py b/gvm/xml.py index 7ac171d8b..7456e7e4d 100644 --- a/gvm/xml.py +++ b/gvm/xml.py @@ -20,30 +20,6 @@ from lxml import etree -FILTER_NAMES = [ - 'Agent', - 'Alert', - 'Asset', - 'Config', - 'Credential', - 'Filter', - 'Group', - 'Note', - 'Override', - 'Permission', - 'Port List', - 'Report', - 'Report Format', - 'Result', - 'Role', - 'Schedule', - 'SecInfo', - 'Tag', - 'Target', - 'Task', - 'User', -] - class XmlCommandElement: def __init__(self, element): @@ -263,37 +239,6 @@ def create_credential_command(self, name, kwargs): return cmd.to_string() - def create_filter_command(self, name, make_unique, kwargs): - """Generates xml string for create filter on gvmd.""" - - cmd = XmlCommand('create_filter') - _xmlname = cmd.add_element('name', name) - if make_unique: - _xmlname.add_element('make_unique', '1') - else: - _xmlname.add_element('make_unique', '0') - - comment = kwargs.get('comment', '') - if comment: - cmd.add_element('comment', comment) - - copy = kwargs.get('copy', '') - if copy: - cmd.add_element('copy', copy) - - term = kwargs.get('term', '') - if term: - cmd.add_element('term', term) - - filter_type = kwargs.get('type', '') - if filter_type: - if filter_type not in FILTER_NAMES: - raise ValueError('create_filter requires type ' - 'to be either cc, snmp, up or usk') - cmd.add_element('type', filter_type) - - return cmd.to_string() - def create_group_command(self, name, kwargs): """Generates xml string for create group on gvmd.""" diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb..4a3d5c0b8 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018 Greenbone Networks GmbH +# +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import functools + + +class CallableMock: + + def __init__(self, func): + # look like the function we are wrapping + functools.update_wrapper(self, func) + self.calls = [] + self.func = func + self.result = None + + def __call__(self, *args, **kwargs): + self.calls.append({'args': args, 'kwargs': kwargs}) + + if self.result is None and not self.func is None: + return self.func(self, *args, **kwargs) + + return self.result + + def return_value(self, value): + self.result = value + + def has_been_called(self): + assert len(self.calls) > 0, "{0} havn't been called.".format( + self.func.__name__) + + def has_been_called_times(self, times): + assert len(self.calls) == times, "{name} haven't been called {times}" \ + " times.".format(name=self.func.__name__, times=times) + + def has_been_called_with(self, *args, **kwargs): + if len(self.calls) == 0: + assert False + + lastcall = self.calls[-1] + + # not sure if this is correct + assert lastcall['args'] == args and lastcall['kwargs'] == kwargs, \ + "Expected arguments {eargs} {ekwargs} of {name} don't match." \ + "Received: {rargs} {rkwargs}".format( + name=self.func.__name__, + eargs=args, + ekwargs=kwargs, + rargs=lastcall['args'], + rkwargs=lastcall['kwargs'] + ) diff --git a/tests/protocols/__init__.py b/tests/protocols/__init__.py index e69de29bb..feb5f1f9f 100644 --- a/tests/protocols/__init__.py +++ b/tests/protocols/__init__.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2018 Greenbone Networks GmbH +# +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from tests import CallableMock + + +class MockConnection: + + @CallableMock + def connect(self): + pass + + @CallableMock + def send(self, data): + pass + + @CallableMock + def read(self): + return '' + + @CallableMock + def disconnect(self): + pass diff --git a/tests/protocols/gmp/test_gmp_create_filter.py b/tests/protocols/gmp/test_gmp_create_filter.py index 8bdb4ad86..fc0edc7ec 100644 --- a/tests/protocols/gmp/test_gmp_create_filter.py +++ b/tests/protocols/gmp/test_gmp_create_filter.py @@ -16,36 +16,67 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +# pylint can't read decorators => disable member checks +# pylint: disable=no-member + import unittest -from gvm.xml import _GmpCommandFactory as GmpCommandFactory, FILTER_NAMES +from gvm.errors import InvalidArgument +from gvm.protocols.gmpv7 import Gmp, FILTER_TYPES +from .. import MockConnection class GMPCreateFilterCommandTestCase(unittest.TestCase): FILTER_NAME = "special filter" def setUp(self): - self.gmp = GmpCommandFactory() - - def test_all_available_filters_correct_cmd(self): - for filter_type in FILTER_NAMES: - cmd = self.gmp.create_filter_command( - name=self.FILTER_NAME, make_unique=True, - kwargs={ - 'term': 'sort-reverse=threat result_hosts_only=1 ' - 'notes=1 overrides=1 levels=hml first=1 rows=1000', - 'type': filter_type - }) - - self.assertEqual( + self.connection = MockConnection() + self.gmp = Gmp(self.connection) + + def test_all_available_filters_types_correct(self): + for filter_type in FILTER_TYPES: + self.gmp.create_filter( + name=self.FILTER_NAME, + term='sort-reverse=threat first=1 rows=1000', + filter_type=filter_type, + ) + + self.connection.send.has_been_called_with( '' - '{0}1' - 'sort-reverse=threat result_hosts_only=1 notes=1 ' - 'overrides=1 levels=hml first=1 rows=1000' + '{0}' + 'sort-reverse=threat first=1 rows=1000' '{1}' ''.format(self.FILTER_NAME, filter_type), - cmd) + ) + + def test_invalid_filters_type(self): + with self.assertRaises(InvalidArgument): + self.gmp.create_filter( + name=self.FILTER_NAME, + term='sort-reverse=threat result_hosts_only=1 ' + 'notes=1 overrides=1 levels=hml first=1 rows=1000', + filter_type='foo', + ) + + def test_all_arguments(self): + self.gmp.create_filter( + name=self.FILTER_NAME, make_unique=True, + term='sort-reverse=threat result_hosts_only=1 ' + 'notes=1 overrides=1 levels=hml first=1 rows=1000', + filter_type='task', + comment='foo', + ) + + self.connection.send.has_been_called_with( + '' + '{0}1' + 'foo' + 'sort-reverse=threat result_hosts_only=1 notes=1 ' + 'overrides=1 levels=hml first=1 rows=1000' + 'task' + ''.format(self.FILTER_NAME), + ) if __name__ == '__main__': diff --git a/tests/protocols/gmp/test_gmp_get_version.py b/tests/protocols/gmp/test_gmp_get_version.py index e89cd3a69..f9875fc51 100644 --- a/tests/protocols/gmp/test_gmp_get_version.py +++ b/tests/protocols/gmp/test_gmp_get_version.py @@ -16,22 +16,28 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +# pylint can't read decorators => disable member checks +# pylint: disable=no-member + import unittest -from gvm.xml import _GmpCommandFactory as GmpCommandFactory +from gvm.protocols.gmpv7 import Gmp +from .. import MockConnection -class GMPGetVersionCommandTestCase(unittest.TestCase): - def setUp(self): - self.gmp = GmpCommandFactory() +class GmpGetVersionCommandTestCase(unittest.TestCase): - def tearDown(self): - pass + def setUp(self): + self.connection = MockConnection() + self.gmp = Gmp(self.connection) def test_get_version(self): - cmd = self.gmp.get_version_command() + self.gmp.get_version() - self.assertEqual('', cmd) + self.connection.connect.has_been_called() + self.connection.read.has_been_called() + self.connection.send.has_been_called() + self.connection.send.has_been_called_with('') if __name__ == '__main__':