diff --git a/gvm/errors.py b/gvm/errors.py
index d7c80da23..57266b500 100644
--- a/gvm/errors.py
+++ b/gvm/errors.py
@@ -21,6 +21,15 @@
class GvmError(Exception):
- """A exception for gvm errors
+ """An exception for gvm errors
"""
pass
+
+
+class InvalidArgument(GvmError):
+ """Raised if an invalid argument/parameter is passed
+ """
+
+class RequiredArgument(GvmError):
+ """Raised if a required argument/parameter is missing
+ """
diff --git a/gvm/protocols/base.py b/gvm/protocols/base.py
index 555426905..135b264f4 100644
--- a/gvm/protocols/base.py
+++ b/gvm/protocols/base.py
@@ -80,6 +80,14 @@ def _transform(self, data):
return data
return transform(data)
+ def _send_xml_command(self, xmlcmd):
+ """Send a xml command to the remote server
+
+ Arguments:
+ xmlcmd (gvm.xml.XmlCommand): XmlCommand instance to send
+ """
+ return self.send_command(xmlcmd.to_string())
+
def is_connected(self):
"""Status of the current connection
diff --git a/gvm/protocols/gmpv7.py b/gvm/protocols/gmpv7.py
index a9368e349..a9564a4fd 100644
--- a/gvm/protocols/gmpv7.py
+++ b/gvm/protocols/gmpv7.py
@@ -22,8 +22,9 @@
from lxml import etree
+from gvm.errors import InvalidArgument, RequiredArgument
from gvm.utils import get_version_string
-from gvm.xml import _GmpCommandFactory as GmpCommandFactory
+from gvm.xml import _GmpCommandFactory as GmpCommandFactory, XmlCommand
from .base import GvmProtocol
@@ -31,6 +32,30 @@
PROTOCOL_VERSION = (7,)
+FILTER_TYPES = [
+ 'agent',
+ 'alert',
+ 'asset',
+ 'config',
+ 'credential',
+ 'filter',
+ 'group',
+ 'note',
+ 'override',
+ 'permission',
+ 'port_list',
+ 'report',
+ 'report_format',
+ 'result',
+ 'role',
+ 'schedule',
+ 'secinfo',
+ 'tag',
+ 'target',
+ 'task',
+ 'user',
+]
+
def _check_command_status(xml):
"""Check gmp response
@@ -163,10 +188,45 @@ def create_credential(self, name, **kwargs):
cmd = self._generator.create_credential_command(name, kwargs)
return self.send_command(cmd)
- def create_filter(self, name, make_unique, **kwargs):
- cmd = self._generator.create_filter_command(name, make_unique,
- kwargs)
- return self.send_command(cmd)
+ def create_filter(self, name, make_unique=False, filter_type=None,
+ comment=None, term=None, copy=None):
+ """Create a new filter
+
+ Arguments:
+ name (str): Name of the new filter
+ make_unique (Boolean):
+ filter_type (str): Filter for entity type
+ comment (str): Comment for the filter
+ term (str): Filter term e.g. 'name=foo'
+ copy (str): UUID of an existing filter
+ """
+ if not name:
+ raise RequiredArgument('create_filter requires a name argument')
+
+ cmd = XmlCommand('create_filter')
+ _xmlname = cmd.add_element('name', name)
+ if make_unique:
+ _xmlname.add_element('make_unique', '1')
+
+ if comment:
+ cmd.add_element('comment', comment)
+
+ # TODO: Move copy into an extra method
+ if copy:
+ cmd.add_element('copy', copy)
+
+ if term:
+ cmd.add_element('term', term)
+
+ if filter_type:
+ filter_type = filter_type.lower()
+ if filter_type not in FILTER_TYPES:
+ raise InvalidArgument(
+ 'create_filter requires type to be one of {0} but '
+ 'was {1}'.format(', '.join(FILTER_TYPES), filter_type))
+ cmd.add_element('type', filter_type)
+
+ return self._send_xml_command(cmd)
def create_group(self, name, **kwargs):
cmd = self._generator.create_group_command(name, kwargs)
diff --git a/gvm/xml.py b/gvm/xml.py
index 7ac171d8b..7456e7e4d 100644
--- a/gvm/xml.py
+++ b/gvm/xml.py
@@ -20,30 +20,6 @@
from lxml import etree
-FILTER_NAMES = [
- 'Agent',
- 'Alert',
- 'Asset',
- 'Config',
- 'Credential',
- 'Filter',
- 'Group',
- 'Note',
- 'Override',
- 'Permission',
- 'Port List',
- 'Report',
- 'Report Format',
- 'Result',
- 'Role',
- 'Schedule',
- 'SecInfo',
- 'Tag',
- 'Target',
- 'Task',
- 'User',
-]
-
class XmlCommandElement:
def __init__(self, element):
@@ -263,37 +239,6 @@ def create_credential_command(self, name, kwargs):
return cmd.to_string()
- def create_filter_command(self, name, make_unique, kwargs):
- """Generates xml string for create filter on gvmd."""
-
- cmd = XmlCommand('create_filter')
- _xmlname = cmd.add_element('name', name)
- if make_unique:
- _xmlname.add_element('make_unique', '1')
- else:
- _xmlname.add_element('make_unique', '0')
-
- comment = kwargs.get('comment', '')
- if comment:
- cmd.add_element('comment', comment)
-
- copy = kwargs.get('copy', '')
- if copy:
- cmd.add_element('copy', copy)
-
- term = kwargs.get('term', '')
- if term:
- cmd.add_element('term', term)
-
- filter_type = kwargs.get('type', '')
- if filter_type:
- if filter_type not in FILTER_NAMES:
- raise ValueError('create_filter requires type '
- 'to be either cc, snmp, up or usk')
- cmd.add_element('type', filter_type)
-
- return cmd.to_string()
-
def create_group_command(self, name, kwargs):
"""Generates xml string for create group on gvmd."""
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29bb..4a3d5c0b8 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2018 Greenbone Networks GmbH
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import functools
+
+
+class CallableMock:
+
+ def __init__(self, func):
+ # look like the function we are wrapping
+ functools.update_wrapper(self, func)
+ self.calls = []
+ self.func = func
+ self.result = None
+
+ def __call__(self, *args, **kwargs):
+ self.calls.append({'args': args, 'kwargs': kwargs})
+
+ if self.result is None and not self.func is None:
+ return self.func(self, *args, **kwargs)
+
+ return self.result
+
+ def return_value(self, value):
+ self.result = value
+
+ def has_been_called(self):
+ assert len(self.calls) > 0, "{0} havn't been called.".format(
+ self.func.__name__)
+
+ def has_been_called_times(self, times):
+ assert len(self.calls) == times, "{name} haven't been called {times}" \
+ " times.".format(name=self.func.__name__, times=times)
+
+ def has_been_called_with(self, *args, **kwargs):
+ if len(self.calls) == 0:
+ assert False
+
+ lastcall = self.calls[-1]
+
+ # not sure if this is correct
+ assert lastcall['args'] == args and lastcall['kwargs'] == kwargs, \
+ "Expected arguments {eargs} {ekwargs} of {name} don't match." \
+ "Received: {rargs} {rkwargs}".format(
+ name=self.func.__name__,
+ eargs=args,
+ ekwargs=kwargs,
+ rargs=lastcall['args'],
+ rkwargs=lastcall['kwargs']
+ )
diff --git a/tests/protocols/__init__.py b/tests/protocols/__init__.py
index e69de29bb..feb5f1f9f 100644
--- a/tests/protocols/__init__.py
+++ b/tests/protocols/__init__.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2018 Greenbone Networks GmbH
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from tests import CallableMock
+
+
+class MockConnection:
+
+ @CallableMock
+ def connect(self):
+ pass
+
+ @CallableMock
+ def send(self, data):
+ pass
+
+ @CallableMock
+ def read(self):
+ return ''
+
+ @CallableMock
+ def disconnect(self):
+ pass
diff --git a/tests/protocols/gmp/test_gmp_create_filter.py b/tests/protocols/gmp/test_gmp_create_filter.py
index 8bdb4ad86..fc0edc7ec 100644
--- a/tests/protocols/gmp/test_gmp_create_filter.py
+++ b/tests/protocols/gmp/test_gmp_create_filter.py
@@ -16,36 +16,67 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
+# pylint can't read decorators => disable member checks
+# pylint: disable=no-member
+
import unittest
-from gvm.xml import _GmpCommandFactory as GmpCommandFactory, FILTER_NAMES
+from gvm.errors import InvalidArgument
+from gvm.protocols.gmpv7 import Gmp, FILTER_TYPES
+from .. import MockConnection
class GMPCreateFilterCommandTestCase(unittest.TestCase):
FILTER_NAME = "special filter"
def setUp(self):
- self.gmp = GmpCommandFactory()
-
- def test_all_available_filters_correct_cmd(self):
- for filter_type in FILTER_NAMES:
- cmd = self.gmp.create_filter_command(
- name=self.FILTER_NAME, make_unique=True,
- kwargs={
- 'term': 'sort-reverse=threat result_hosts_only=1 '
- 'notes=1 overrides=1 levels=hml first=1 rows=1000',
- 'type': filter_type
- })
-
- self.assertEqual(
+ self.connection = MockConnection()
+ self.gmp = Gmp(self.connection)
+
+ def test_all_available_filters_types_correct(self):
+ for filter_type in FILTER_TYPES:
+ self.gmp.create_filter(
+ name=self.FILTER_NAME,
+ term='sort-reverse=threat first=1 rows=1000',
+ filter_type=filter_type,
+ )
+
+ self.connection.send.has_been_called_with(
''
- '{0}1'
- 'sort-reverse=threat result_hosts_only=1 notes=1 '
- 'overrides=1 levels=hml first=1 rows=1000'
+ '{0}'
+ 'sort-reverse=threat first=1 rows=1000'
'{1}'
''.format(self.FILTER_NAME, filter_type),
- cmd)
+ )
+
+ def test_invalid_filters_type(self):
+ with self.assertRaises(InvalidArgument):
+ self.gmp.create_filter(
+ name=self.FILTER_NAME,
+ term='sort-reverse=threat result_hosts_only=1 '
+ 'notes=1 overrides=1 levels=hml first=1 rows=1000',
+ filter_type='foo',
+ )
+
+ def test_all_arguments(self):
+ self.gmp.create_filter(
+ name=self.FILTER_NAME, make_unique=True,
+ term='sort-reverse=threat result_hosts_only=1 '
+ 'notes=1 overrides=1 levels=hml first=1 rows=1000',
+ filter_type='task',
+ comment='foo',
+ )
+
+ self.connection.send.has_been_called_with(
+ ''
+ '{0}1'
+ 'foo'
+ 'sort-reverse=threat result_hosts_only=1 notes=1 '
+ 'overrides=1 levels=hml first=1 rows=1000'
+ 'task'
+ ''.format(self.FILTER_NAME),
+ )
if __name__ == '__main__':
diff --git a/tests/protocols/gmp/test_gmp_get_version.py b/tests/protocols/gmp/test_gmp_get_version.py
index e89cd3a69..f9875fc51 100644
--- a/tests/protocols/gmp/test_gmp_get_version.py
+++ b/tests/protocols/gmp/test_gmp_get_version.py
@@ -16,22 +16,28 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
+# pylint can't read decorators => disable member checks
+# pylint: disable=no-member
+
import unittest
-from gvm.xml import _GmpCommandFactory as GmpCommandFactory
+from gvm.protocols.gmpv7 import Gmp
+from .. import MockConnection
-class GMPGetVersionCommandTestCase(unittest.TestCase):
- def setUp(self):
- self.gmp = GmpCommandFactory()
+class GmpGetVersionCommandTestCase(unittest.TestCase):
- def tearDown(self):
- pass
+ def setUp(self):
+ self.connection = MockConnection()
+ self.gmp = Gmp(self.connection)
def test_get_version(self):
- cmd = self.gmp.get_version_command()
+ self.gmp.get_version()
- self.assertEqual('', cmd)
+ self.connection.connect.has_been_called()
+ self.connection.read.has_been_called()
+ self.connection.send.has_been_called()
+ self.connection.send.has_been_called_with('')
if __name__ == '__main__':