diff --git a/azure-iot-device/azure/iot/device/common/auth/__init__.py b/azure-iot-device/azure/iot/device/common/auth/__init__.py new file mode 100644 index 000000000..7c55f882f --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/auth/__init__.py @@ -0,0 +1,6 @@ +from .signing_mechanism import SymmetricKeySigningMechanism + +# NOTE: Please import the connection_string and sastoken modules directly +# rather than through the package interface, as the modules contain many +# related items for their respective domains, which we do not wish to expose +# at length here. diff --git a/azure-iot-device/azure/iot/device/common/connection_string.py b/azure-iot-device/azure/iot/device/common/auth/connection_string.py similarity index 82% rename from azure-iot-device/azure/iot/device/common/connection_string.py rename to azure-iot-device/azure/iot/device/common/auth/connection_string.py index 064b2a731..629514bbc 100644 --- a/azure-iot-device/azure/iot/device/common/connection_string.py +++ b/azure-iot-device/azure/iot/device/common/auth/connection_string.py @@ -32,8 +32,17 @@ def _parse_connection_string(connection_string): """Return a dictionary of values contained in a given connection string """ - cs_args = connection_string.split(CS_DELIMITER) - d = dict(arg.split(CS_VAL_SEPARATOR, 1) for arg in cs_args) + try: + cs_args = connection_string.split(CS_DELIMITER) + except (AttributeError, TypeError): + # NOTE: in Python 2.7, bytes will not raise an error here as they do in all other versions + raise TypeError("Connection String must be of type str") + try: + d = dict(arg.split(CS_VAL_SEPARATOR, 1) for arg in cs_args) + except ValueError: + # This occurs in an extreme edge case where a dictionary cannot be formed because there + # is only 1 token after the split (dict requires two in order to make a key/value pair) + raise ValueError("Invalid Connection String - Unable to parse") if len(cs_args) != len(d): # various errors related to incorrect parsing - duplicate args, bad syntax, etc. raise ValueError("Invalid Connection String - Unable to parse") diff --git a/azure-iot-device/azure/iot/device/common/auth/sastoken.py b/azure-iot-device/azure/iot/device/common/auth/sastoken.py new file mode 100644 index 000000000..291bd9be2 --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/auth/sastoken.py @@ -0,0 +1,96 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
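# --- Editorial example (not part of this diff) -------------------------------
# A minimal sketch of the hardened _parse_connection_string behaviour above,
# assuming ConnectionString (defined in the same module) delegates to it.
# The hub name, device id and key are hypothetical placeholders.
from azure.iot.device.common.auth import connection_string as cs

parsed = cs.ConnectionString(
    "HostName=my-hub.azure-devices.net;DeviceId=my-device;SharedAccessKey=Zm9vYmFy"
)
assert parsed[cs.HOST_NAME] == "my-hub.azure-devices.net"
assert parsed.get(cs.MODULE_ID) is None      # optional values retrieved via .get()

try:
    cs.ConnectionString(1234)                # non-str input
except TypeError:
    pass                                     # "Connection String must be of type str"

try:
    cs.ConnectionString("not-a-connection-string")
except ValueError:
    pass                                     # "Invalid Connection String - Unable to parse"
# ------------------------------------------------------------------------------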
+# -------------------------------------------------------------------------- +"""This module contains tools for working with Shared Access Signature (SAS) Tokens""" + +import base64 +import hmac +import hashlib +import time +import six.moves.urllib as urllib +from azure.iot.device.common.chainable_exception import ChainableException + + +class SasTokenError(ChainableException): + """Error in SasToken""" + + pass + + +class SasToken(object): + """Shared Access Signature Token used to authenticate a request + + Data Attributes: + expiry_time (int): Time that token will expire (in UTC, since epoch) + ttl (int): Time to live for the token, in seconds + """ + + _auth_rule_token_format = ( + "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}&skn={keyname}" + ) + _simple_token_format = "SharedAccessSignature sr={resource}&sig={signature}&se={expiry}" + + def __init__(self, uri, signing_mechanism, key_name=None, ttl=3600): + """ + :param str uri: URI of the resouce to be accessed + :param signing_mechanism: The signing mechanism to use in the SasToken + :type signing_mechanism: Child classes of :class:`azure.iot.common.SigningMechanism` + :param str key_name: Symmetric Key Name (optional) + :param int ttl: Time to live for the token, in seconds (default 3600) + + :raises: SasTokenError if an error occurs building a SasToken + """ + self._uri = uri + self._signing_mechanism = signing_mechanism + self._key_name = key_name + self._expiry_time = None # This will be overwritten by the .refresh() call below + self._token = None # This will be overwritten by the .refresh() call below + + self.ttl = ttl + self.refresh() + + def __str__(self): + return self._token + + def refresh(self): + """ + Refresh the SasToken lifespan, giving it a new expiry time, and generating a new token. + """ + self._expiry_time = int(time.time() + self.ttl) + self._token = self._build_token() + + def _build_token(self): + """Buid SasToken representation + + :returns: String representation of the token + """ + url_encoded_uri = urllib.parse.quote(self._uri, safe="") + message = url_encoded_uri + "\n" + str(self.expiry_time) + try: + signature = self._signing_mechanism.sign(message) + except Exception as e: + # Because of variant signing mechanisms, we don't know what error might be raised. + # So we catch all of them. + raise SasTokenError("Unable to build SasToken from given values", e) + url_encoded_signature = urllib.parse.quote(signature, safe="") + if self._key_name: + token = self._auth_rule_token_format.format( + resource=url_encoded_uri, + signature=url_encoded_signature, + expiry=str(self.expiry_time), + keyname=self._key_name, + ) + else: + token = self._simple_token_format.format( + resource=url_encoded_uri, + signature=url_encoded_signature, + expiry=str(self.expiry_time), + ) + return token + + @property + def expiry_time(self): + """Expiry Time is READ ONLY""" + return self._expiry_time diff --git a/azure-iot-device/azure/iot/device/common/auth/signing_mechanism.py b/azure-iot-device/azure/iot/device/common/auth/signing_mechanism.py new file mode 100644 index 000000000..4bc634a02 --- /dev/null +++ b/azure-iot-device/azure/iot/device/common/auth/signing_mechanism.py @@ -0,0 +1,73 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
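# --- Editorial example (not part of this diff) -------------------------------
# A minimal usage sketch for the new SasToken. The URI and the base64 key are
# hypothetical placeholders; SymmetricKeySigningMechanism is the concrete
# signing mechanism added in signing_mechanism.py below and exported from the
# auth package earlier in this change.
from azure.iot.device.common.auth import SymmetricKeySigningMechanism
from azure.iot.device.common.auth.sastoken import SasToken

signing_mechanism = SymmetricKeySigningMechanism(key="Zm9vYmFy")  # base64 of "foobar"
token = SasToken(
    uri="my-hub.azure-devices.net/devices/my-device",
    signing_mechanism=signing_mechanism,
    ttl=3600,
)

str(token)          # "SharedAccessSignature sr=...&sig=...&se=<expiry>"
token.expiry_time   # read-only; roughly now + 3600 seconds
token.refresh()     # re-signs with a fresh expiry, as SasTokenRenewalStage does later
# ------------------------------------------------------------------------------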
+# -------------------------------------------------------------------------- +"""This module defines an abstract SigningMechanism, as well as common child implementations of it +""" + +import six +import abc +import hmac +import hashlib +import base64 +from six.moves import urllib + + +@six.add_metaclass(abc.ABCMeta) +class SigningMechanism(object): + @abc.abstractmethod + def sign(self, data_str): + pass + + +class SymmetricKeySigningMechanism(SigningMechanism): + def __init__(self, key): + """ + A mechanism that signs data using a symmetric key + + :param key: Symmetric Key (base64 encoded) + :type key: str or bytes + """ + # Convert key to bytes + try: + key = key.encode("utf-8") + except AttributeError: + # If byte string, no need to encode + pass + + # Derives the signing key + # CT-TODO: is "signing key" the right term? + try: + self._signing_key = base64.b64decode(key) + except (base64.binascii.Error, TypeError): + # NOTE: TypeError can only be raised in Python 2.7 + raise ValueError("Invalid Symmetric Key") + + def sign(self, data_str): + """ + Sign a data string with symmetric key and the HMAC-SHA256 algorithm. + + :param data_str: Data string to be signed + :type data_str: str or bytes + + :returns: The signed data + :rtype: str + """ + # Convert data_str to bytes + try: + data_str = data_str.encode("utf-8") + except AttributeError: + # If byte string, no need to encode + pass + + # Derive signature via HMAC-SHA256 algorithm + try: + hmac_digest = hmac.HMAC( + key=self._signing_key, msg=data_str, digestmod=hashlib.sha256 + ).digest() + signed_data = base64.b64encode(hmac_digest) + except (TypeError): + raise ValueError("Unable to sign string using the provided symmetric key") + # Convert from bytes to string + return signed_data.decode("utf-8") diff --git a/azure-iot-device/azure/iot/device/common/callable_weak_method.py b/azure-iot-device/azure/iot/device/common/callable_weak_method.py index dfb9ba2b3..a543b6e6a 100644 --- a/azure-iot-device/azure/iot/device/common/callable_weak_method.py +++ b/azure-iot-device/azure/iot/device/common/callable_weak_method.py @@ -63,7 +63,7 @@ def __init__(self, object, method_name): self.method_name = method_name def _get_method(self): - return getattr(self.object_weakref(), self.method_name) + return getattr(self.object_weakref(), self.method_name, None) def __call__(self, *args, **kwargs): return self._get_method()(*args, **kwargs) diff --git a/azure-iot-device/azure/iot/device/common/http_transport.py b/azure-iot-device/azure/iot/device/common/http_transport.py index 7aa81f4ab..7fffbbbfe 100644 --- a/azure-iot-device/azure/iot/device/common/http_transport.py +++ b/azure-iot-device/azure/iot/device/common/http_transport.py @@ -88,6 +88,11 @@ def request(self, method, path, callback, body="", headers={}, query_params=""): logger.debug("connecting to host tcp socket") connection.connect() logger.debug("connection succeeded") + # TODO: URL formation should be moved to pipeline_stages_iothub_http, I believe, as + # depending on the operation this could have a different hostname, due to different + # destinations. 
For now this isn't a problem yet, because no possible client can + # support more than one HTTP operation + # (Device can do File Upload but NOT Method Invoke, Module can do Method Inovke and NOT file upload) url = "https://{hostname}/{path}{query_params}".format( hostname=self._hostname, path=path, diff --git a/azure-iot-device/azure/iot/device/common/mqtt_transport.py b/azure-iot-device/azure/iot/device/common/mqtt_transport.py index 9e492dd08..59d98c2f3 100644 --- a/azure-iot-device/azure/iot/device/common/mqtt_transport.py +++ b/azure-iot-device/azure/iot/device/common/mqtt_transport.py @@ -155,6 +155,7 @@ def _create_mqtt_client(self): ) if self._proxy_options: + logger.info("Setting custom proxy options on mqtt client") mqtt_client.proxy_set( proxy_type=self._proxy_options.proxy_type, proxy_addr=self._proxy_options.proxy_address, @@ -325,12 +326,15 @@ def _create_ssl_context(self): ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLSv1_2) if self._server_verification_cert: + logger.debug("configuring SSL context with custom server verification cert") ssl_context.load_verify_locations(cadata=self._server_verification_cert) else: + logger.debug("configuring SSL context with default certs") ssl_context.load_default_certs() if self._cipher: try: + logger.debug("configuring SSL context with cipher suites") ssl_context.set_ciphers(self._cipher) except ssl.SSLError as e: # TODO: custom error with more detail? diff --git a/azure-iot-device/azure/iot/device/common/pipeline/config.py b/azure-iot-device/azure/iot/device/common/pipeline/config.py index 759cb7014..3f46c3a12 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/config.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/config.py @@ -19,9 +19,28 @@ class BasePipelineConfig(object): config files. """ - def __init__(self, websockets=False, cipher="", proxy_options=None): + def __init__( + self, + hostname, + gateway_hostname=None, + sastoken=None, + x509=None, + server_verification_cert=None, + websockets=False, + cipher="", + proxy_options=None, + ): """Initializer for BasePipelineConfig + :param str hostname: The hostname being connected to + :param str gateway_hostname: The gateway hostname optionally being used + :param sastoken: SasToken to be used for authentication. Mutually exclusive with x509. + :type sastoken: :class:`azure.iot.device.common.auth.SasToken` + :param x509: X509 to be used for authentication. Mutually exclusive with sastoken. + :type x509: :class:`azure.iot.device.models.X509` + :param str server_verification_cert: The trusted certificate chain. + Necessary when using connecting to an endpoint which has a non-standard root of trust, + such as a protocol gateway. :param bool websockets: Enabling/disabling websockets in MQTT. This feature is relevant if a firewall blocks port 8883 from use. 
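# --- Editorial example (not part of this diff) -------------------------------
# Sketch of constructing the expanded BasePipelineConfig with the parameters
# described in the docstring above; the hostname and key are hypothetical
# placeholders. Exactly one of sastoken/x509 must be supplied (see the
# ValueError check added to __init__ just below).
from azure.iot.device.common import auth
from azure.iot.device.common.auth import sastoken as st
from azure.iot.device.common.pipeline.config import BasePipelineConfig

signing_mechanism = auth.SymmetricKeySigningMechanism(key="Zm9vYmFy")
sastoken = st.SasToken("my-hub.azure-devices.net/devices/my-device", signing_mechanism)

pipeline_configuration = BasePipelineConfig(
    hostname="my-hub.azure-devices.net",    # now a required argument
    sastoken=sastoken,                      # mutually exclusive with x509
    server_verification_cert=None,          # only for non-standard roots of trust
    websockets=False,
)

try:
    BasePipelineConfig(hostname="my-hub.azure-devices.net")  # neither sastoken nor x509
except ValueError:
    pass  # "One of either 'sastoken' or 'x509' must be provided"
# ------------------------------------------------------------------------------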
:param cipher: Optional cipher suite(s) for TLS/SSL, as a string in @@ -30,6 +49,16 @@ def __init__(self, websockets=False, cipher="", proxy_options=None): :param proxy_options: Details of proxy configuration :type proxy_options: :class:`azure.iot.device.common.models.ProxyOptions` """ + # Network + self.hostname = hostname + self.gateway_hostname = gateway_hostname + + # Auth + self.sastoken = sastoken + self.x509 = x509 + if (not sastoken and not x509) or (sastoken and x509): + raise ValueError("One of either 'sastoken' or 'x509' must be provided") + self.server_verification_cert = server_verification_cert self.websockets = websockets self.cipher = self._sanitize_cipher(cipher) self.proxy_options = proxy_options diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_base.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_base.py index 04ab66736..c82d370f9 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_base.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_base.py @@ -217,6 +217,17 @@ def on_worker_op_complete(op, error): return worker_op +class InitializePipelineOperation(PipelineOperation): + """ + A PipelineOperation for doing initial setup of the pipeline + + Attributes can be dynamically added to this operation for use in other stages if necessary + (e.g. initialization requires a derived value) + """ + + pass + + class ConnectOperation(PipelineOperation): """ A PipelineOperation object which tells the pipeline to connect to whatever service it needs to connect to. @@ -318,31 +329,6 @@ def __init__(self, feature_name, callback): self.feature_name = feature_name -class UpdateSasTokenOperation(PipelineOperation): - """ - A PipelineOperation object which contains a SAS token used for connecting. This operation was likely initiated - by a pipeline stage that knows how to generate SAS tokens. - - This operation is in the group of base operations because many different clients use the concept of a SAS token. - - Even though this is an base operation, it will most likely be generated and also handled by more specifics stages - (such as IoTHub or MQTT stages). - """ - - def __init__(self, sas_token, callback): - """ - Initializer for UpdateSasTokenOperation objects. - - :param str sas_token: The token string which will be used to authenticate with whatever - service this pipeline connects with. - :param Function callback: The function that gets called when this operation is complete or has - failed. The callback function must accept A PipelineOperation object which indicates - the specific operation which has completed or failed. - """ - super(UpdateSasTokenOperation, self).__init__(callback=callback) - self.sas_token = sas_token - - class RequestAndResponseOperation(PipelineOperation): """ A PipelineOperation object which wraps the common operation of sending a request to iothub with a request_id ($rid) diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_http.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_http.py index 44d3e0176..47b576232 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_http.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_http.py @@ -6,35 +6,6 @@ from . import PipelineOperation -class SetHTTPConnectionArgsOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to connect to a server using the HTTP protocol. 
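# --- Editorial example (not part of this diff) -------------------------------
# Illustrative sketch of how a stage consumes the new InitializePipelineOperation
# defined above. "ExampleSetupStage" is hypothetical; the real consumers in this
# change are SasTokenRenewalStage and the HTTP/MQTT transport stages further down.
from azure.iot.device.common.pipeline import pipeline_ops_base, pipeline_thread
from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage


class ExampleSetupStage(PipelineStage):
    @pipeline_thread.runs_on_pipeline_thread
    def _run_op(self, op):
        if isinstance(op, pipeline_ops_base.InitializePipelineOperation):
            # Connection/auth details now live on the pipeline root's configuration
            # rather than being delivered via Set*ConnectionArgs operations.
            pl_config = self.pipeline_root.pipeline_configuration
            if pl_config.sastoken:
                pass  # e.g. start timers, derive values, or attach attributes to op
            self.send_op_down(op)
        else:
            self.send_op_down(op)
# ------------------------------------------------------------------------------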
- - This operation is in the group of HTTP operations because its attributes are very specific to the HTTP protocol. - """ - - def __init__( - self, hostname, callback, server_verification_cert=None, client_cert=None, sas_token=None - ): - """ - Initializer for SetHTTPConnectionArgsOperation objects. - :param str hostname: The hostname of the HTTP server we will eventually connect to - :param str server_verification_cert: (Optional) The server verification certificate to use - if the HTTP server that we're going to connect to uses server-side TLS - :param X509 client_cert: (Optional) The x509 object containing a client certificate and key used to connect - to the HTTP service - :param str sas_token: The token string which will be used to authenticate with the service - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. - """ - super(SetHTTPConnectionArgsOperation, self).__init__(callback=callback) - self.hostname = hostname - self.server_verification_cert = server_verification_cert - self.client_cert = client_cert - self.sas_token = sas_token - - class HTTPRequestAndResponseOperation(PipelineOperation): """ A PipelineOperation object which contains arguments used to connect to a server using the HTTP protocol. diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_mqtt.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_mqtt.py index 651c95dc9..6f2f314fb 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_mqtt.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_ops_mqtt.py @@ -6,47 +6,6 @@ from . import PipelineOperation -class SetMQTTConnectionArgsOperation(PipelineOperation): - """ - A PipelineOperation object which contains arguments used to connect to a server using the MQTT protocol. - - This operation is in the group of MQTT operations because its attributes are very specific to the MQTT protocol. - """ - - def __init__( - self, - client_id, - hostname, - username, - callback, - server_verification_cert=None, - client_cert=None, - sas_token=None, - ): - """ - Initializer for SetMQTTConnectionArgsOperation objects. - - :param str client_id: The client identifier to use when connecting to the MQTT server - :param str hostname: The hostname of the MQTT server we will eventually connect to - :param str username: The username to use when connecting to the MQTT server - :param str server_verification_cert: (Optional) The server verification certificate to use - if the MQTT server that we're going to connect to uses server-side TLS - :param X509 client_cert: (Optional) The x509 object containing a client certificate and key used to connect - to the MQTT service - :param str sas_token: The token string which will be used to authenticate with the service - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. 
- """ - super(SetMQTTConnectionArgsOperation, self).__init__(callback=callback) - self.client_id = client_id - self.hostname = hostname - self.username = username - self.server_verification_cert = server_verification_cert - self.client_cert = client_cert - self.sas_token = sas_token - - class MQTTPublishOperation(PipelineOperation): """ A PipelineOperation object which contains arguments used to publish a specific payload on a specific topic using the MQTT protocol. diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_base.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_base.py index a8ab98793..9017bf8e1 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_base.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_base.py @@ -1,4 +1,4 @@ -# -------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. @@ -12,8 +12,8 @@ import traceback import uuid import weakref -from six.moves import queue import threading +from six.moves import queue from . import pipeline_events_base from . import pipeline_ops_base, pipeline_ops_mqtt from . import pipeline_thread @@ -273,6 +273,109 @@ def _handle_pipeline_event(self, event): logger.warning("incoming pipeline event with no handler. dropping.") +# NOTE: This stage could be a candidate for being refactored into some kind of other +# pipeline-related structure. What's odd about it as a stage is that it doesn't really respond +# to operations or events so much as it spawns them on a timer. +# Perhaps some kind of... Pipeline Daemon? +class SasTokenRenewalStage(PipelineStage): + # Amount of time, in seconds, prior to token expiration, when the renewal process will begin + DEFAULT_TOKEN_RENEWAL_MARGIN = 120 + + def __init__(self): + super(SasTokenRenewalStage, self).__init__() + self._token_renewal_timer = None + + @pipeline_thread.runs_on_pipeline_thread + def _run_op(self, op): + if ( + isinstance(op, pipeline_ops_base.InitializePipelineOperation) + and self.pipeline_root.pipeline_configuration.sastoken + ): + self._start_renewal_timer() + self.send_op_down(op) + else: + self.send_op_down(op) + + @pipeline_thread.runs_on_pipeline_thread + def _cancel_token_renewal_timer(self): + """Cancel and delete any pending renewal timer""" + timer = self._token_renewal_timer + self._token_renewal_timer = None + if timer: + logger.debug("Cancelling SAS Token renewal timer") + timer.cancel() + + @pipeline_thread.runs_on_pipeline_thread + def _start_renewal_timer(self): + """Begin a renewal timer. + When the timer expires, and the token is renewed, a new timer will be set""" + self._cancel_token_renewal_timer() + # NOTE: The assumption here is that the SasToken has just been refreshed, so there is + # approximately 'TTL' seconds until expiration. In practice this could probably be off + # a few seconds given processing time. We could make this more accurate if SasToken + # objects had a live TTL value rather than a static one (i.e. 
"there are n seconds + # remaining in the lifespan of this token", rather than "this token was intended to live + # for n seconds") + seconds_until_renewal = ( + self.pipeline_root.pipeline_configuration.sastoken.ttl + - self.DEFAULT_TOKEN_RENEWAL_MARGIN + ) + if seconds_until_renewal < 0: + # This shouldn't happen in correct flow, but it's possible I suppose, if something goes + # horribly awry elsewhere in the stack, or if we start allowing for custom + # SasToken TTLs or custom Renewal Margins in the future + logger.error("ERROR: SasToken TTL less than Renewal Margin") + handle_exceptions.handle_background_exception( + pipeline_exceptions.PipelineError("SasToken TTL less than Renewal Margin!") + ) + else: + logger.debug( + "Scheduling SAS Token renewal for {} seconds in the future".format( + seconds_until_renewal + ) + ) + self_weakref = weakref.ref(self) + + @pipeline_thread.runs_on_pipeline_thread + def on_reauthorize_complete(op, error): + this = self_weakref() + if error: + logger.error( + "{}({}): reauthorize connection operation failed. Error={}".format( + this.name, op.name, error + ) + ) + handle_exceptions.handle_background_exception(error) + else: + logger.debug( + "{}({}): reauthorize connection operation is complete".format( + this.name, op.name + ) + ) + + @pipeline_thread.invoke_on_pipeline_thread_nowait + def renew_token(): + this = self_weakref() + logger.debug("Renewing SAS Token") + # Renew the token + sastoken = this.pipeline_root.pipeline_configuration.sastoken + sastoken.refresh() + # If the pipeline is already connected, send order to reauthorize the connection + # now that token has been renewed + if this.pipeline_root.connected: + this.send_op_down( + pipeline_ops_base.ReauthorizeConnectionOperation( + callback=on_reauthorize_complete + ) + ) + # Once again, start a renewal timer + this._start_renewal_timer() + + self._token_renewal_timer = threading.Timer(seconds_until_renewal, renew_token) + self._token_renewal_timer.daemon = True + self._token_renewal_timer.start() + + class AutoConnectStage(PipelineStage): """ This stage is responsible for ensuring that the protocol is connected when diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_http.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_http.py index 24de37f92..b4bf6409c 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_http.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_http.py @@ -38,25 +38,37 @@ def __init__(self): @pipeline_thread.runs_on_pipeline_thread def _run_op(self, op): - if isinstance(op, pipeline_ops_http.SetHTTPConnectionArgsOperation): - # pipeline_ops_http.SetHTTPConenctionArgsOperation is used to create the HTTPTransport object and set all of it's properties. + if isinstance(op, pipeline_ops_base.InitializePipelineOperation): + + # If there is a gateway hostname, use that as the hostname for connection, + # rather than the hostname itself + if self.pipeline_root.pipeline_configuration.gateway_hostname: + logger.debug( + "Gateway Hostname Present. Setting Hostname to: {}".format( + self.pipeline_root.pipeline_configuration.gateway_hostname + ) + ) + hostname = self.pipeline_root.pipeline_configuration.gateway_hostname + else: + logger.debug( + "Gateway Hostname not present. 
Setting Hostname to: {}".format( + self.pipeline_root.pipeline_configuration.hostname + ) + ) + hostname = self.pipeline_root.pipeline_configuration.hostname + + # Create HTTP Transport logger.debug("{}({}): got connection args".format(self.name, op.name)) - self.sas_token = op.sas_token self.transport = HTTPTransport( - hostname=op.hostname, - server_verification_cert=op.server_verification_cert, - x509_cert=op.client_cert, + hostname=hostname, + server_verification_cert=self.pipeline_root.pipeline_configuration.server_verification_cert, + x509_cert=self.pipeline_root.pipeline_configuration.x509, cipher=self.pipeline_root.pipeline_configuration.cipher, ) self.pipeline_root.transport = self.transport op.complete() - elif isinstance(op, pipeline_ops_base.UpdateSasTokenOperation): - logger.debug("{}({}): saving sas token and completing".format(self.name, op.name)) - self.sas_token = op.sas_token - op.complete() - elif isinstance(op, pipeline_ops_http.HTTPRequestAndResponseOperation): # This will call down to the HTTP Transport with a request and also created a request callback. Because the HTTP Transport will run on the http transport thread, this call should be non-blocking to the pipline thread. logger.debug( @@ -85,10 +97,14 @@ def on_request_completed(error=None, response=None): op.reason = response["reason"] op.complete() - # A deepcopy is necessary here since otherwise the manipulation happening to http_headers will affect the op.headers, which would be an unintended side effect and not a good practice. + # A deepcopy is necessary here since otherwise the manipulation happening to + # http_headers will affect the op.headers, which would be an unintended side effect + # and not a good practice. http_headers = copy.deepcopy(op.headers) - if self.sas_token: - http_headers["Authorization"] = self.sas_token + if self.pipeline_root.pipeline_configuration.sastoken: + http_headers["Authorization"] = str( + self.pipeline_root.pipeline_configuration.sastoken + ) self.transport.request( method=op.method, diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_mqtt.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_mqtt.py index 814ef9605..7df1e8cb8 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_mqtt.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_stages_mqtt.py @@ -38,9 +38,6 @@ class MQTTTransportStage(PipelineStage): def __init__(self): super(MQTTTransportStage, self).__init__() - # The sas_token will be set when Connetion Args are received - self.sas_token = None - # The transport will be instantiated when Connection Args are received self.transport = None @@ -111,17 +108,33 @@ def _cancel_connection_watchdog(self, op): @pipeline_thread.runs_on_pipeline_thread def _run_op(self, op): - if isinstance(op, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation): - # pipeline_ops_mqtt.SetMQTTConnectionArgsOperation is where we create our MQTTTransport object and set - # all of its properties. + if isinstance(op, pipeline_ops_base.InitializePipelineOperation): + + # If there is a gateway hostname, use that as the hostname for connection, + # rather than the hostname itself + if self.pipeline_root.pipeline_configuration.gateway_hostname: + logger.debug( + "Gateway Hostname Present. Setting Hostname to: {}".format( + self.pipeline_root.pipeline_configuration.gateway_hostname + ) + ) + hostname = self.pipeline_root.pipeline_configuration.gateway_hostname + else: + logger.debug( + "Gateway Hostname not present. 
Setting Hostname to: {}".format( + self.pipeline_root.pipeline_configuration.hostname + ) + ) + hostname = self.pipeline_root.pipeline_configuration.hostname + + # Create the Transport object, set it's handlers logger.debug("{}({}): got connection args".format(self.name, op.name)) - self.sas_token = op.sas_token self.transport = MQTTTransport( client_id=op.client_id, - hostname=op.hostname, + hostname=hostname, username=op.username, - server_verification_cert=op.server_verification_cert, - x509_cert=op.client_cert, + server_verification_cert=self.pipeline_root.pipeline_configuration.server_verification_cert, + x509_cert=self.pipeline_root.pipeline_configuration.x509, websockets=self.pipeline_root.pipeline_configuration.websockets, cipher=self.pipeline_root.pipeline_configuration.cipher, proxy_options=self.pipeline_root.pipeline_configuration.proxy_options, @@ -153,19 +166,20 @@ def _run_op(self, op): op.complete() - elif isinstance(op, pipeline_ops_base.UpdateSasTokenOperation): - logger.debug("{}({}): saving sas token and completing".format(self.name, op.name)) - self.sas_token = op.sas_token - op.complete() - elif isinstance(op, pipeline_ops_base.ConnectOperation): logger.info("{}({}): connecting".format(self.name, op.name)) self._cancel_pending_connection_op() self._pending_connection_op = op self._start_connection_watchdog(op) + # Use SasToken as password if present. If not present (e.g. using X509), + # then no password is required because auth is handled via other means. + if self.pipeline_root.pipeline_configuration.sastoken: + password = str(self.pipeline_root.pipeline_configuration.sastoken) + else: + password = None try: - self.transport.connect(password=self.sas_token) + self.transport.connect(password=password) except Exception as e: logger.error("transport.connect raised error") logger.error(traceback.format_exc()) @@ -181,7 +195,9 @@ def _run_op(self, op): self._pending_connection_op = op self._start_connection_watchdog(op) try: - self.transport.reauthorize_connection(password=self.sas_token) + self.transport.reauthorize_connection( + password=str(self.pipeline_root.pipeline_configuration.sastoken) + ) except Exception as e: logger.error("transport.reauthorize_connection raised error") logger.error(traceback.format_exc()) diff --git a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_thread.py b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_thread.py index d50611d77..269129f25 100644 --- a/azure-iot-device/azure/iot/device/common/pipeline/pipeline_thread.py +++ b/azure-iot-device/azure/iot/device/common/pipeline/pipeline_thread.py @@ -191,7 +191,7 @@ def wrapper(*args, **kwargs): It should be. You should use invoke_on_{thread_name}_thread(_nowait) to enter the {thread_name} thread before calling this function. If you're hitting this from inside a test function, you may need to add the fake_pipeline_thread fixture to - your test. (grep for apply_fake_pipeline_thread) """.format( + your test. (generally applied on the global pytestmark in a module) """.format( function_name=func.__name__, thread_name=thread_name ) diff --git a/azure-iot-device/azure/iot/device/common/sastoken.py b/azure-iot-device/azure/iot/device/common/sastoken.py deleted file mode 100644 index a6a21159c..000000000 --- a/azure-iot-device/azure/iot/device/common/sastoken.py +++ /dev/null @@ -1,85 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
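# --- Editorial example (not part of this diff) -------------------------------
# Worked example of the renewal timing used by SasTokenRenewalStage above,
# using the default values from this change.
ttl = 3600                                    # default SasToken TTL, in seconds
renewal_margin = 120                          # SasTokenRenewalStage.DEFAULT_TOKEN_RENEWAL_MARGIN
seconds_until_renewal = ttl - renewal_margin  # 3480 -> timer fires 58 minutes in

# When the timer fires, the stage (roughly) does the following:
#   1. sastoken.refresh()                        -> new expiry = now + ttl
#   2. if the pipeline is connected, send a ReauthorizeConnectionOperation down;
#      the MQTT stage then reconnects using password=str(sastoken)
#   3. schedule the next renewal another (ttl - margin) seconds out
# ------------------------------------------------------------------------------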
See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains tools for working with Shared Access Signature (SAS) Tokens""" - -import base64 -import hmac -import hashlib -import time -import six.moves.urllib as urllib - - -class SasTokenError(Exception): - """Error in SasToken""" - - def __init__(self, message, cause=None): - """Initializer for SasTokenError - - :param str message: Error message - :param cause: Exception that caused this error (optional) - """ - super(SasTokenError, self).__init__(message) - self.cause = cause - - -class SasToken(object): - """Shared Access Signature Token used to authenticate a request - - Parameters: - uri (str): URI of the resouce to be accessed - key_name (str): Shared Access Key Name - key (str): Shared Access Key (base64 encoded) - ttl (int)[default 3600]: Time to live for the token, in seconds - - Data Attributes: - expiry_time (int): Time that token will expire (in UTC, since epoch) - ttl (int): Time to live for the token, in seconds - - Raises: - SasTokenError if trying to build a SasToken from invalid values - """ - - _encoding_type = "utf-8" - _service_token_format = "SharedAccessSignature sr={}&sig={}&se={}&skn={}" - _device_token_format = "SharedAccessSignature sr={}&sig={}&se={}" - - def __init__(self, uri, key, key_name=None, ttl=3600): - self._uri = urllib.parse.quote_plus(uri) - self._key = key - self._key_name = key_name - self.ttl = ttl - self.refresh() - - def __str__(self): - return self._token - - def refresh(self): - """ - Refresh the SasToken lifespan, giving it a new expiry time - """ - self.expiry_time = int(time.time() + self.ttl) - self._token = self._build_token() - - def _build_token(self): - """Buid SasToken representation - - Returns: - String representation of the token - """ - try: - message = (self._uri + "\n" + str(self.expiry_time)).encode(self._encoding_type) - signing_key = base64.b64decode(self._key.encode(self._encoding_type)) - signed_hmac = hmac.HMAC(signing_key, message, hashlib.sha256) - signature = urllib.parse.quote(base64.b64encode(signed_hmac.digest())) - except (TypeError, base64.binascii.Error) as e: - raise SasTokenError("Unable to build SasToken from given values", e) - if self._key_name: - token = self._service_token_format.format( - self._uri, signature, str(self.expiry_time), self._key_name - ) - else: - token = self._device_token_format.format(self._uri, signature, str(self.expiry_time)) - return token diff --git a/azure-iot-device/azure/iot/device/iothub/abstract_clients.py b/azure-iot-device/azure/iot/device/iothub/abstract_clients.py index 68162c4c2..34dca98c3 100644 --- a/azure-iot-device/azure/iot/device/iothub/abstract_clients.py +++ b/azure-iot-device/azure/iot/device/iothub/abstract_clients.py @@ -11,20 +11,14 @@ import logging import os import io -from . import auth from . import pipeline +from azure.iot.device.common.auth import connection_string as cs +from azure.iot.device.common.auth import sastoken as st -logger = logging.getLogger(__name__) +from azure.iot.device.common import auth +from . import edge_hsm -# A note on implementation: -# The intializer methods accept pipeline(s) instead of an auth provider in order to protect -# the client from logic related to authentication providers. This reduces edge cases, and allows -# pipeline configuration to be specifically tailored to the method of instantiation. 
-# For instance, .create_from_connection_string and .create_from_edge_envrionment both can use -# SymmetricKeyAuthenticationProviders to instantiate pipeline(s), but only .create_from_edge_environment -# should use it to instantiate an HTTPPipeline. If the initializer accepted an auth provider, and then -# used it to create pipelines, this detail would be lost, as there would be no way to tell if a -# SymmetricKeyAuthenticationProvider was intended to be part of an Edge scenario or not. +logger = logging.getLogger(__name__) def _validate_kwargs(**kwargs): @@ -43,18 +37,13 @@ def _validate_kwargs(**kwargs): raise TypeError("Got an unexpected keyword argument '{}'".format(kwarg)) -def _get_pipeline_config_kwargs(**kwargs): - """Helper function to get a subset of user provided kwargs relevant to IoTHubPipelineConfig""" - new_kwargs = {} - if "product_info" in kwargs: - new_kwargs["product_info"] = kwargs["product_info"] - if "websockets" in kwargs: - new_kwargs["websockets"] = kwargs["websockets"] - if "cipher" in kwargs: - new_kwargs["cipher"] = kwargs["cipher"] - if "proxy_options" in kwargs: - new_kwargs["proxy_options"] = kwargs["proxy_options"] - return new_kwargs +def _form_sas_uri(hostname, device_id, module_id=None): + if module_id: + return "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=hostname, device_id=device_id, module_id=module_id + ) + else: + return "{hostname}/devices/{device_id}".format(hostname=hostname, device_id=device_id) @six.add_metaclass(abc.ABCMeta) @@ -98,24 +87,41 @@ def create_from_connection_string(cls, connection_string, **kwargs): :returns: An instance of an IoTHub client that uses a connection string for authentication. """ # TODO: Make this device/module specific and reject non-matching connection strings. - # This will require refactoring of the auth package to use common objects (e.g. ConnectionString) - # in order to differentiate types of connection strings. 
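# --- Editorial example (not part of this diff) -------------------------------
# Worked example of the new _form_sas_uri helper defined above (a module-internal
# helper); the hostname and ids are hypothetical placeholders.
#
#   _form_sas_uri("my-hub.azure-devices.net", "my-device")
#   -> "my-hub.azure-devices.net/devices/my-device"
#
#   _form_sas_uri("my-hub.azure-devices.net", "my-device", "my-module")
#   -> "my-hub.azure-devices.net/devices/my-device/modules/my-module"
#
# The resulting URI is what the signing mechanism signs to build the SasToken in
# the factory methods below.
# ------------------------------------------------------------------------------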
+ # Ensure no invalid kwargs were passed by the user _validate_kwargs(**kwargs) + # Create SasToken + connection_string = cs.ConnectionString(connection_string) + uri = _form_sas_uri( + hostname=connection_string[cs.HOST_NAME], + device_id=connection_string[cs.DEVICE_ID], + module_id=connection_string.get(cs.MODULE_ID), + ) + signing_mechanism = auth.SymmetricKeySigningMechanism( + key=connection_string[cs.SHARED_ACCESS_KEY] + ) + try: + sastoken = st.SasToken(uri, signing_mechanism) + except st.SasTokenError as e: + new_err = ValueError("Could not create a SasToken using provided connection string") + new_err.__cause__ = e + raise new_err # Pipeline Config setup - pipeline_config_kwargs = _get_pipeline_config_kwargs(**kwargs) - pipeline_configuration = pipeline.IoTHubPipelineConfig(**pipeline_config_kwargs) + pipeline_configuration = pipeline.IoTHubPipelineConfig( + device_id=connection_string[cs.DEVICE_ID], + module_id=connection_string.get(cs.MODULE_ID), + hostname=connection_string[cs.HOST_NAME], + gateway_hostname=connection_string.get(cs.GATEWAY_HOST_NAME), + sastoken=sastoken, + **kwargs + ) if cls.__name__ == "IoTHubDeviceClient": pipeline_configuration.blob_upload = True - # Auth Provider setup - authentication_provider = auth.SymmetricKeyAuthenticationProvider.parse(connection_string) - authentication_provider.server_verification_cert = kwargs.get("server_verification_cert") - # Pipeline setup - http_pipeline = pipeline.HTTPPipeline(authentication_provider, pipeline_configuration) - mqtt_pipeline = pipeline.MQTTPipeline(authentication_provider, pipeline_configuration) + http_pipeline = pipeline.HTTPPipeline(pipeline_configuration) + mqtt_pipeline = pipeline.MQTTPipeline(pipeline_configuration) return cls(mqtt_pipeline, http_pipeline) @@ -192,22 +198,18 @@ def create_from_x509_certificate(cls, x509, hostname, device_id, **kwargs): :returns: An instance of an IoTHub client that uses an X509 certificate for authentication. """ + # Ensure no invalid kwargs were passed by the user _validate_kwargs(**kwargs) # Pipeline Config setup - pipeline_config_kwargs = _get_pipeline_config_kwargs(**kwargs) - pipeline_configuration = pipeline.IoTHubPipelineConfig(**pipeline_config_kwargs) - pipeline_configuration.blob_upload = True # Blob Upload is a feature on Device Clients - - # Auth Provider setup - authentication_provider = auth.X509AuthenticationProvider( - x509=x509, hostname=hostname, device_id=device_id + pipeline_configuration = pipeline.IoTHubPipelineConfig( + device_id=device_id, hostname=hostname, x509=x509, **kwargs ) - authentication_provider.server_verification_cert = kwargs.get("server_verification_cert") + pipeline_configuration.blob_upload = True # Blob Upload is a feature on Device Clients # Pipeline setup - http_pipeline = pipeline.HTTPPipeline(authentication_provider, pipeline_configuration) - mqtt_pipeline = pipeline.MQTTPipeline(authentication_provider, pipeline_configuration) + http_pipeline = pipeline.HTTPPipeline(pipeline_configuration) + mqtt_pipeline = pipeline.MQTTPipeline(pipeline_configuration) return cls(mqtt_pipeline, http_pipeline) @@ -238,22 +240,28 @@ def create_from_symmetric_key(cls, symmetric_key, hostname, device_id, **kwargs) :return: An instance of an IoTHub client that uses a symmetric key for authentication. 
""" + # Ensure no invalid kwargs were passed by the user _validate_kwargs(**kwargs) - # Pipeline Config setup - pipeline_config_kwargs = _get_pipeline_config_kwargs(**kwargs) - pipeline_configuration = pipeline.IoTHubPipelineConfig(**pipeline_config_kwargs) - pipeline_configuration.blob_upload = True # Blob Upload is a feature on Device Clients + # Create SasToken + uri = _form_sas_uri(hostname=hostname, device_id=device_id) + signing_mechanism = auth.SymmetricKeySigningMechanism(key=symmetric_key) + try: + sastoken = st.SasToken(uri, signing_mechanism) + except st.SasTokenError as e: + new_err = ValueError("Could not create a SasToken using provided values") + new_err.__cause__ = e + raise new_err - # Auth Provider setup - authentication_provider = auth.SymmetricKeyAuthenticationProvider( - hostname=hostname, device_id=device_id, module_id=None, shared_access_key=symmetric_key + # Pipeline Config setup + pipeline_configuration = pipeline.IoTHubPipelineConfig( + device_id=device_id, hostname=hostname, sastoken=sastoken, **kwargs ) - authentication_provider.server_verification_cert = kwargs.get("server_verification_cert") + pipeline_configuration.blob_upload = True # Blob Upload is a feature on Device Clients # Pipeline setup - http_pipeline = pipeline.HTTPPipeline(authentication_provider, pipeline_configuration) - mqtt_pipeline = pipeline.MQTTPipeline(authentication_provider, pipeline_configuration) + http_pipeline = pipeline.HTTPPipeline(pipeline_configuration) + mqtt_pipeline = pipeline.MQTTPipeline(pipeline_configuration) return cls(mqtt_pipeline, http_pipeline) @@ -296,6 +304,7 @@ def create_from_edge_environment(cls, **kwargs): :returns: An instance of an IoTHub client that uses the IoT Edge environment for authentication. """ + # Ensure no invalid kwargs were passed by the user _validate_kwargs(**kwargs) if kwargs.get("server_verification_cert"): raise TypeError( @@ -322,8 +331,9 @@ def create_from_edge_environment(cls, **kwargs): new_err = OSError("IoT Edge environment not configured correctly") new_err.__cause__ = e raise new_err - # TODO: variant server_verification_cert file vs data object that would remove the need for this fopen + # Read the certificate file to pass it on as a string + # TODO: variant server_verification_cert file vs data object that would remove the need for this fopen try: with io.open(ca_cert_filepath, mode="r") as ca_cert_file: server_verification_cert = ca_cert_file.read() @@ -338,41 +348,66 @@ def create_from_edge_environment(cls, **kwargs): new_err = ValueError("Invalid CA certificate file") new_err.__cause__ = e raise new_err - # Use Symmetric Key authentication for local dev experience. + + # Extract config values from connection string + connection_string = cs.ConnectionString(connection_string) try: - authentication_provider = auth.SymmetricKeyAuthenticationProvider.parse( - connection_string - ) - except ValueError: - raise - authentication_provider.server_verification_cert = server_verification_cert + device_id = connection_string[cs.DEVICE_ID] + module_id = connection_string[cs.MODULE_ID] + hostname = connection_string[cs.HOST_NAME] + gateway_hostname = connection_string[cs.GATEWAY_HOST_NAME] + except KeyError: + raise ValueError("Invalid Connection String") + + # Use Symmetric Key authentication for local dev experience. 
+ signing_mechanism = auth.SymmetricKeySigningMechanism( + key=connection_string[cs.SHARED_ACCESS_KEY] + ) + else: # Use an HSM for authentication in the general case + hsm = edge_hsm.IoTEdgeHsm( + module_id=module_id, + generation_id=module_generation_id, + workload_uri=workload_uri, + api_version=api_version, + ) try: - authentication_provider = auth.IoTEdgeAuthenticationProvider( - hostname=hostname, - device_id=device_id, - module_id=module_id, - gateway_hostname=gateway_hostname, - module_generation_id=module_generation_id, - workload_uri=workload_uri, - api_version=api_version, - ) - except auth.IoTEdgeError as e: + server_verification_cert = hsm.get_certificate() + except edge_hsm.IoTEdgeError as e: new_err = OSError("Unexpected failure in IoTEdge") new_err.__cause__ = e raise new_err + signing_mechanism = hsm + + # Create SasToken + uri = _form_sas_uri(hostname=hostname, device_id=device_id, module_id=module_id) + try: + sastoken = st.SasToken(uri, signing_mechanism) + except st.SasTokenError as e: + new_err = ValueError( + "Could not create a SasToken using the values in the Edge environment" + ) + new_err.__cause__ = e + raise new_err # Pipeline Config setup - pipeline_config_kwargs = _get_pipeline_config_kwargs(**kwargs) - pipeline_configuration = pipeline.IoTHubPipelineConfig(**pipeline_config_kwargs) + pipeline_configuration = pipeline.IoTHubPipelineConfig( + device_id=device_id, + module_id=module_id, + hostname=hostname, + gateway_hostname=gateway_hostname, + sastoken=sastoken, + server_verification_cert=server_verification_cert, + **kwargs + ) pipeline_configuration.method_invoke = ( True ) # Method Invoke is allowed on modules created from edge environment # Pipeline setup - http_pipeline = pipeline.HTTPPipeline(authentication_provider, pipeline_configuration) - mqtt_pipeline = pipeline.MQTTPipeline(authentication_provider, pipeline_configuration) + http_pipeline = pipeline.HTTPPipeline(pipeline_configuration) + mqtt_pipeline = pipeline.MQTTPipeline(pipeline_configuration) return cls(mqtt_pipeline, http_pipeline) @@ -408,21 +443,17 @@ def create_from_x509_certificate(cls, x509, hostname, device_id, module_id, **kw :returns: An instance of an IoTHub client that uses an X509 certificate for authentication. 
""" + # Ensure no invalid kwargs were passed by the user _validate_kwargs(**kwargs) # Pipeline Config setup - pipeline_config_kwargs = _get_pipeline_config_kwargs(**kwargs) - pipeline_configuration = pipeline.IoTHubPipelineConfig(**pipeline_config_kwargs) - - # Auth Provider setup - authentication_provider = auth.X509AuthenticationProvider( - x509=x509, hostname=hostname, device_id=device_id, module_id=module_id + pipeline_configuration = pipeline.IoTHubPipelineConfig( + device_id=device_id, module_id=module_id, hostname=hostname, x509=x509, **kwargs ) - authentication_provider.server_verification_cert = kwargs.get("server_verification_cert") # Pipeline setup - http_pipeline = pipeline.HTTPPipeline(authentication_provider, pipeline_configuration) - mqtt_pipeline = pipeline.MQTTPipeline(authentication_provider, pipeline_configuration) + http_pipeline = pipeline.HTTPPipeline(pipeline_configuration) + mqtt_pipeline = pipeline.MQTTPipeline(pipeline_configuration) return cls(mqtt_pipeline, http_pipeline) @abc.abstractmethod diff --git a/azure-iot-device/azure/iot/device/iothub/aio/async_clients.py b/azure-iot-device/azure/iot/device/iothub/aio/async_clients.py index e9b40a342..be7ce625c 100644 --- a/azure-iot-device/azure/iot/device/iothub/aio/async_clients.py +++ b/azure-iot-device/azure/iot/device/iothub/aio/async_clients.py @@ -437,10 +437,12 @@ async def send_message_to_output(self, message, output_name): message.output_name = output_name logger.info("Sending message to output:" + output_name + "...") - send_output_event_async = async_adapter.emulate_async(self._mqtt_pipeline.send_output_event) + send_output_message_async = async_adapter.emulate_async( + self._mqtt_pipeline.send_output_message + ) callback = async_adapter.AwaitableCallback() - await send_output_event_async(message, callback=callback) + await send_output_message_async(message, callback=callback) await handle_result(callback) logger.info("Successfully sent message to output: " + output_name) diff --git a/azure-iot-device/azure/iot/device/iothub/auth/__init__.py b/azure-iot-device/azure/iot/device/iothub/auth/__init__.py deleted file mode 100644 index 967326d50..000000000 --- a/azure-iot-device/azure/iot/device/iothub/auth/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Azure IoT Hub Device SDK Authentication - -This package provides authentication-related functionality for use with the -Azure IoT Hub Device SDK. -""" - -from .sk_authentication_provider import SymmetricKeyAuthenticationProvider -from .sas_authentication_provider import SharedAccessSignatureAuthenticationProvider -from .iotedge_authentication_provider import IoTEdgeAuthenticationProvider, IoTEdgeError -from .x509_authentication_provider import X509AuthenticationProvider - -__all__ = [ - "SymmetricKeyAuthenticationProvider", - "SharedAccessSignatureAuthenticationProvider", - "IoTEdgeAuthenticationProvider", - "IoTEdgeError", - "X509AuthenticationProvider", -] diff --git a/azure-iot-device/azure/iot/device/iothub/auth/authentication_provider.py b/azure-iot-device/azure/iot/device/iothub/auth/authentication_provider.py deleted file mode 100644 index e98df731a..000000000 --- a/azure-iot-device/azure/iot/device/iothub/auth/authentication_provider.py +++ /dev/null @@ -1,34 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import abc -import six - - -@six.add_metaclass(abc.ABCMeta) -class AuthenticationProvider(object): - """Super class for all providing known types of authentication mechanism like - x509 and SAS based authentication. - - :ivar str hostname: Hostname - :ivar str device_id: Device ID - :ivar str module_id: Module ID - """ - - def __init__(self, hostname, device_id, module_id=None): - """Initializer for AuthenticationProvider - - :param str hostname: Hostname - :param str device_id: Device ID - :param str module_id: Module ID (optional) - """ - self.hostname = hostname - self.device_id = device_id - self.module_id = module_id - - -# TODO: Potentially some additional abstract class that defines an abstract .get_current_sas_token() -# in order to enforce sas token retrieval in various sas-affiliated auths (sas, sk)? diff --git a/azure-iot-device/azure/iot/device/iothub/auth/base_renewable_token_authentication_provider.py b/azure-iot-device/azure/iot/device/iothub/auth/base_renewable_token_authentication_provider.py deleted file mode 100644 index ba8eaec9a..000000000 --- a/azure-iot-device/azure/iot/device/iothub/auth/base_renewable_token_authentication_provider.py +++ /dev/null @@ -1,232 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module provides a base class for renewable token authentication providers""" - -import time -import abc -import logging -import math -import six -import weakref -from threading import Timer -import six.moves.urllib as urllib -from .authentication_provider import AuthenticationProvider - -logger = logging.getLogger(__name__) - -_device_keyname_token_format = "SharedAccessSignature sr={}&sig={}&se={}&skn={}" -_device_token_format = "SharedAccessSignature sr={}&sig={}&se={}" - -# Length of time, in seconds, that a SAS token is valid for. -DEFAULT_TOKEN_VALIDITY_PERIOD = 3600 - -# Length of time, in seconds, before a token expires that we want to begin renewing it. -DEFAULT_TOKEN_RENEWAL_MARGIN = 120 - - -@six.add_metaclass(abc.ABCMeta) -class BaseRenewableTokenAuthenticationProvider(AuthenticationProvider): - """A base class for authentication providers which are based on SAS (Shared - Authentication Signature) strings which are able to be renewed. - - The SAS token string renewal is based on a signing function that is used - to create the sig field of the SAS string. This base implements all - functionality for SAS string creation except for the signing function, - which is expected to be provided by derived objects. This base also - implements the functionality necessary for timing and executing the - token renewal operation. - """ - - def __init__(self, hostname, device_id, module_id=None): - """Initializer for Renewable Token Authentication Provider. - - This object is intended as a base class and cannot be used directly. - A derived class which provides a signing function (such as - SymmetricKeyAuthenticationProvider or IoTEdgeAuthenticationProvider) - should be used instead. 
- - :param str hostname: The hostname - :param str device_id: The device ID - :param str module_id: The module ID (optional) - """ - - super(BaseRenewableTokenAuthenticationProvider, self).__init__( - hostname=hostname, device_id=device_id, module_id=module_id - ) - self.token_validity_period = DEFAULT_TOKEN_VALIDITY_PERIOD - self.token_renewal_margin = DEFAULT_TOKEN_RENEWAL_MARGIN - self._token_update_timer = None - self.shared_access_key_name = None - self.sas_token_str = None - self.on_sas_token_updated_handler_list = [] - - def __del__(self): - self._cancel_token_update_timer() - - def generate_new_sas_token(self): - """Force the SAS token to update itself. - - This will cause a new sas token to be created using the _sign function. - This token is valid for roughly self.token_validity_period second. - - This validity period can only be roughly enforced because it relies on the - coordination of clocks between the client device and the service. If the two - different machines have different definitions of "now", most likely because - of clock drift, then they will also have different notions of when a token will - expire. This algorithm atempts to compensate for clock drift by taking - self.token_renewal_margin into account when deciding when to renew a token. - - If self.token_udpate_callback is set, this callback will be called to notify the - pipeline that a new token is available. The pipeline is responsible for doing - whatever is necessary to leverage the new token when the on_sas_token_updated_handler_list - function is called. - - The token that is generated expires at some point in the future, based on the token - renewal interval and the token renewal margin. When a token is first generated, the - authorization provider object will set a timer which will be responsible for renewing - the token before the it expires. When this timer fires, it will automatically generate - a new sas token and notify the pipeline by calling self.on_sas_token_updated_handler_list. - - The token update timer is set based on two numbers: self.token_validity_period and - self.token_renewal_margin - - The first number is the validity period. This defines the amount of time that the token - is valid. The interval is encoded in the token as an offset from the current time, - as based on the Unix epoch. In other words, the expiry (se=) value in the token - is the number of seconds after 00:00 on January 1, 1970 that the token expires. - - The second number that defines the token renewal behavior is the margin. This is - the number of seconds before expiration that we want to generate a new token. Since - the clocks on different computers can drift over time, they will all have different - definitions of what "now" is, so the margin needs to be set so there is a - very small chance that there is no time overlap where one computer thinks the token - is expired and another doesn't. - - When the timer is set to renew the SAS token, the timer is set for - (token_validity_period - token_renewal_margin) seconds in the future. In this way, - the token will be renewed close to it's expiration time, but not so close that - we risk a problem caused by clock drift. 
- - :return: None - """ - logger.info( - "Generating new SAS token for (%s,%s) that expires %d seconds in the future", - self.device_id, - self.module_id, - self.token_validity_period, - ) - expiry = int(math.floor(time.time()) + self.token_validity_period) - resource_uri = self.hostname + "/devices/" + self.device_id - if self.module_id: - resource_uri += "/modules/" + self.module_id - quoted_resource_uri = urllib.parse.quote_plus(resource_uri) - - signature = self._sign(quoted_resource_uri, expiry) - - if self.shared_access_key_name: - token = _device_keyname_token_format.format( - quoted_resource_uri, signature, str(expiry), self.shared_access_key_name - ) - else: - token = _device_token_format.format(quoted_resource_uri, signature, str(expiry)) - - self.sas_token_str = str(token) - self._schedule_token_update(self.token_validity_period - self.token_renewal_margin) - self._notify_token_updated() - - def _cancel_token_update_timer(self): - """Cancel any future token update operations. This is typically done as part of a - teardown operation. - """ - t = self._token_update_timer - self._token_update_timer = None - if t: - logger.debug( - "Canceling token update timer for (%s,%s)", - self.device_id, - self.module_id if self.module_id else "", - ) - t.cancel() - - def _schedule_token_update(self, seconds_until_update): - """Schedule an automatic sas token update to take place seconds_until_update seconds in - the future. If an update was previously scheduled, this method shall cancel the - previously-scheduled update and schedule a new update. - """ - self._cancel_token_update_timer() - logger.debug( - "Scheduling token update for (%s,%s) for %d seconds in the future", - self.device_id, - self.module_id, - seconds_until_update, - ) - - # It's important to use a weak reference to self inside this timer function - # because we don't want the timer to prevent this object (`self`) from being collected. - # - # We want `self` to get collected when the pipeline gets collected, and - # we want the pipeline to get collected when the client object gets collected. - # This way, everything gets cleaned up when the user is done with the client object, - # as expected. - # - # If timerfunc used `self` directly, that would be a strong reference, and that strong - # reference would prevent `self` from being collected as long as the timer existed. - # - # If this isn't collected when the client is collected, then the object that implements the - # on_sas_token_updated_hndler doesn't get collected. Since that object is part of the - # pipeline, a major part of the pipeline ends up staying around, probably orphaned from - # the client. Since that orphaned part of the pipeline contains Paho, bad things can happen - # if we don't clean up Paho correctly. This is especially noticable if one process - # destroys a client object and creates a new one. - # - self_weakref = weakref.ref(self) - - def timerfunc(): - this = self_weakref() - logger.debug("Timed SAS update for (%s,%s)", this.device_id, this.module_id) - this.generate_new_sas_token() - - self._token_update_timer = Timer(seconds_until_update, timerfunc) - self._token_update_timer.daemon = True - self._token_update_timer.start() - - def _notify_token_updated(self): - """Notify clients that the SAS token has been updated by calling self.on_sas_token_updated. - In response to this event, clients should re-initiate their connection in order to use - the updated sas token. 
- """ - if bool(len(self.on_sas_token_updated_handler_list)): - logger.debug( - "sending token update notification for (%s, %s)", self.device_id, self.module_id - ) - for x in self.on_sas_token_updated_handler_list: - x() - else: - logger.warning( - "_notify_token_updated: on_sas_token_updated_handler_list not set. Doing nothing." - ) - - def get_current_sas_token(self): - """Get the current SharedAuthenticationSignature string. - - This string can be used to authenticate with an Azure IoT Hub or Azure IoT Edge Hub service. - - If a SAS token has not yet been created yet, this function call the generate_new_sas_token - function to create a new token and schedule the update timer. See the documentation for - generate_new_sas_token for more detail. - - :return: The current shared access signature token in string form. - """ - if not self.sas_token_str: - self.generate_new_sas_token() - return self.sas_token_str - - @abc.abstractmethod - def _sign(self, quoted_resource_uri, expiry): - """Create and return a new signature for this object. The caller is responsible - for placing the signature inside the sig field of a SAS token string. - """ - pass diff --git a/azure-iot-device/azure/iot/device/iothub/auth/sas_authentication_provider.py b/azure-iot-device/azure/iot/device/iothub/auth/sas_authentication_provider.py deleted file mode 100644 index 33d732899..000000000 --- a/azure-iot-device/azure/iot/device/iothub/auth/sas_authentication_provider.py +++ /dev/null @@ -1,121 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -from .authentication_provider import AuthenticationProvider - -""" -The urllib, urllib2, and urlparse modules from Python 2 have been combined in the urllib package in Python 3 -The six.moves.urllib package is a python version-independent location of the above functionality. -""" -import six.moves.urllib as urllib - -logger = logging.getLogger(__name__) - -URI_SEPARATOR = "/" -DELIMITER = "&" -VALUE_SEPARATOR = "=" -PARTS_SEPARATOR = " " - -SIGNATURE = "sig" -SHARED_ACCESS_KEY_NAME = "skn" -RESOURCE_URI = "sr" -EXPIRY = "se" - -_valid_keys = [SIGNATURE, SHARED_ACCESS_KEY_NAME, RESOURCE_URI, EXPIRY] - - -class SharedAccessSignatureAuthenticationProvider(AuthenticationProvider): - """ - The Shared Access Signature Authentication Provider. - This provider already contains the sas token which will be needed to authenticate with The IoT hub. - """ - - def __init__(self, hostname, device_id, module_id, sas_token_str): - """ - Constructor for Shared Access Signature Authentication Provider - """ - logger.info("Using SAS authentication for {%s, %s, %s}", hostname, device_id, module_id) - super(SharedAccessSignatureAuthenticationProvider, self).__init__( - hostname=hostname, device_id=device_id, module_id=module_id - ) - self.sas_token_str = sas_token_str - - def get_current_sas_token(self): - """ - :return: the string representation of the current Shared Access Signature - """ - return self.sas_token_str - - @staticmethod - def parse(sas_token_str): - """ - This method creates a Shared Access Signature Authentication Provider from a string, and sets properties for each of the parsed - fields in the string. Also validates the required properties of the shared access signature. 
- :param sas_token_str: The ampersand-delimited string of 'name=value' pairs. - The input may look like the following formations:- - SharedAccessSignature sr=&sig=&se= - SharedAccessSignature sr=&sig=&skn=&se= - :return: The Shared Access Signature Authentication Provider constructed - """ - try: - parts = sas_token_str.split(PARTS_SEPARATOR) - sas_args = parts[1].split(DELIMITER) - d = dict(arg.split(VALUE_SEPARATOR, 1) for arg in sas_args) - except (IndexError, ValueError, AttributeError): - raise ValueError( - "The Shared Access Signature is required and should not be empty or blank and must be supplied as a string consisting of two parts in the format 'SharedAccessSignature sr=&sig=&se=' with an optional skn=" - ) - - if len(parts) != 2: - raise ValueError( - "The Shared Access Signature must be of the format 'SharedAccessSignature sr=&sig=&se=' or/and it can additionally contain an optional skn= name=value pair." - ) - - if len(sas_args) != len(d): - raise ValueError("Invalid Shared Access Signature - Unable to parse") - if not all(key in _valid_keys for key in d.keys()): - raise ValueError( - "Invalid keys in Shared Access Signature. The valid keys are sr, sig, se and an optional skn." - ) - - _validate_required_keys(d) - - try: - unquoted_resource_uri = urllib.parse.unquote_plus(d.get(RESOURCE_URI)) - url_segments = unquoted_resource_uri.split(URI_SEPARATOR) - - module_id = None - hostname = url_segments[0] - device_id = url_segments[2] - - if len(url_segments) > 4: - module_id = url_segments[4] - - return SharedAccessSignatureAuthenticationProvider( - hostname, device_id, module_id, sas_token_str - ) - except IndexError: - raise ValueError( - "One of the name value pair of the Shared Access Signature string should be a proper resource uri" - ) - - -def _validate_required_keys(d): - """ - Validates that required keys are present. - Raise ValueError if incorrect combination of keys - """ - resource_uri = d.get(RESOURCE_URI) - signature = d.get(SIGNATURE) - expiry = d.get(EXPIRY) - - if resource_uri and signature and expiry: - pass - else: - raise ValueError( - "Invalid Shared Access Signature. It must be of the format 'SharedAccessSignature sr=&sig=&se=' or/and it can additionally contain an optional skn= name=value pair." - ) diff --git a/azure-iot-device/azure/iot/device/iothub/auth/sk_authentication_provider.py b/azure-iot-device/azure/iot/device/iothub/auth/sk_authentication_provider.py deleted file mode 100644 index d6d309c36..000000000 --- a/azure-iot-device/azure/iot/device/iothub/auth/sk_authentication_provider.py +++ /dev/null @@ -1,135 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import base64 -import hmac -import hashlib -import logging -import six.moves.urllib as urllib -from .base_renewable_token_authentication_provider import BaseRenewableTokenAuthenticationProvider - -logger = logging.getLogger(__name__) - -DELIMITER = ";" -VALUE_SEPARATOR = "=" - -HOST_NAME = "HostName" -SHARED_ACCESS_KEY_NAME = "SharedAccessKeyName" -SHARED_ACCESS_KEY = "SharedAccessKey" -SHARED_ACCESS_SIGNATURE = "SharedAccessSignature" -DEVICE_ID = "DeviceId" -MODULE_ID = "ModuleId" -GATEWAY_HOST_NAME = "GatewayHostName" - -_valid_keys = [ - HOST_NAME, - SHARED_ACCESS_KEY_NAME, - SHARED_ACCESS_KEY, - SHARED_ACCESS_SIGNATURE, - DEVICE_ID, - MODULE_ID, - GATEWAY_HOST_NAME, -] - - -class SymmetricKeyAuthenticationProvider(BaseRenewableTokenAuthenticationProvider): - """ - A Symmetric Key Authentication Provider. This provider needs to create the i - Shared Access Signature that would be needed to connect to the IoT Hub. - """ - - def __init__( - self, - hostname, - device_id, - module_id, - shared_access_key, - shared_access_key_name=None, - gateway_hostname=None, - ): - """ - - Constructor for SymmetricKey Authentication Provider - """ - logger.info( - "Using Shared Key authentication for {%s, %s, %s}", hostname, device_id, module_id - ) - - super(SymmetricKeyAuthenticationProvider, self).__init__( - hostname=hostname, device_id=device_id, module_id=module_id - ) - self.shared_access_key = shared_access_key - self.shared_access_key_name = shared_access_key_name - self.gateway_hostname = gateway_hostname - self.server_verification_cert = None - - @staticmethod - def parse(connection_string): - """ - This method creates a Symmetric Key Authentication Provider from a given connection string, and sets properties for each of the parsed - fields in the string. Also validates the required properties of the connection string. - :param connection_string: The semicolon-delimited string of 'name=value' pairs. - The input may look like the following formations:- - HostName=;DeviceId=;SharedAccessKey= - HostName=;DeviceId=;SharedAccessKeyName=;SharedAccessKey= - HostName=;DeviceId=;ModuleId=;SharedAccessKey= - :return: The Symmetric Key Authentication Provider constructed - """ - try: - cs_args = connection_string.split(DELIMITER) - d = dict(arg.split(VALUE_SEPARATOR, 1) for arg in cs_args) - except (ValueError, AttributeError): - raise ValueError( - "Connection string is required and should not be empty or blank and must be supplied as a string" - ) - if len(cs_args) != len(d): - raise ValueError("Invalid Connection String - Unable to parse") - if not all(key in _valid_keys for key in d.keys()): - raise ValueError("Invalid Connection String - Invalid Key") - - _validate_keys(d) - - return SymmetricKeyAuthenticationProvider( - d.get(HOST_NAME), - d.get(DEVICE_ID), - d.get(MODULE_ID), - d.get(SHARED_ACCESS_KEY), - d.get(SHARED_ACCESS_KEY_NAME), - d.get(GATEWAY_HOST_NAME), - ) - - def _sign(self, quoted_resource_uri, expiry): - """ - Creates the base64-encoded HMAC-SHA256 hash of the string to sign. The string to sign is constructed from the - resource_uri and expiry and the signing key is constructed from the device_key. - :param quoted_resource_uri: the resource URI to encode into the token, already URI-encoded - :param expiry: an integer value representing the number of seconds since the epoch 00:00:00 UTC on 1 January 1970 at which the token will expire. - :return: The signature portion of the Sas Token. 
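As a runnable restatement of the signing step this docstring describes (and which the implementation just below performs), with a fabricated device key and resource URI standing in for real values:

    import base64
    import hashlib
    import hmac
    import time
    import six.moves.urllib as urllib

    # Fabricated inputs for illustration only
    shared_access_key = base64.b64encode(b"not a real device key").decode()
    quoted_resource_uri = urllib.parse.quote_plus("myhub.azure-devices.net/devices/mydevice")
    expiry = int(time.time()) + 3600

    # HMAC-SHA256 over "<quoted uri>\n<expiry>", keyed with the base64-decoded device key,
    # then base64-encoded and URL-quoted to form the sig= portion of the token.
    message = (quoted_resource_uri + "\n" + str(expiry)).encode("utf-8")
    signing_key = base64.b64decode(shared_access_key.encode("utf-8"))
    signed_hmac = hmac.HMAC(signing_key, message, hashlib.sha256)
    signature = urllib.parse.quote(base64.b64encode(signed_hmac.digest()))
    print(signature)
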
- """ - try: - message = (quoted_resource_uri + "\n" + str(expiry)).encode("utf-8") - signing_key = base64.b64decode(self.shared_access_key.encode("utf-8")) - signed_hmac = hmac.HMAC(signing_key, message, hashlib.sha256) - signature = urllib.parse.quote(base64.b64encode(signed_hmac.digest())) - except (TypeError, base64.binascii.Error): - raise ValueError("Unable to build shared access signature from given values") - return signature - - -def _validate_keys(d): - """Raise ValueError if incorrect combination of keys - """ - host_name = d.get(HOST_NAME) - shared_access_key_name = d.get(SHARED_ACCESS_KEY_NAME) - shared_access_key = d.get(SHARED_ACCESS_KEY) - device_id = d.get(DEVICE_ID) - - if host_name and device_id and shared_access_key: - pass - elif host_name and shared_access_key and shared_access_key_name: - pass - else: - raise ValueError("Invalid Connection String - Incomplete") diff --git a/azure-iot-device/azure/iot/device/iothub/auth/x509_authentication_provider.py b/azure-iot-device/azure/iot/device/iothub/auth/x509_authentication_provider.py deleted file mode 100644 index f352ff72d..000000000 --- a/azure-iot-device/azure/iot/device/iothub/auth/x509_authentication_provider.py +++ /dev/null @@ -1,43 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import logging -from .authentication_provider import AuthenticationProvider - -logger = logging.getLogger(__name__) - - -class X509AuthenticationProvider(AuthenticationProvider): - """ - An X509 Authentication Provider. This provider uses the certificate and key provided to - authenticate a device with an Azure IoT Hub instance. X509 Authentication is only supported - for device identities connecting directly to an Azure IoT hub. - """ - - def __init__(self, x509, hostname, device_id, module_id=None): - """ - Constructor for X509 Authentication Provider - :param x509: The X509 object containing certificate, key and passphrase - :param hostname: The hostname of the Azure IoT hub. - :param device_id: The device unique identifier as it exists in the Azure IoT Hub device registry. - :param module_id: The module unique identifier of the device. It is not applicable when dealing with only devices. - """ - logger.info( - "Using X509 authentication for {hostname},{device_id},{module_id}".format( - hostname=hostname, device_id=device_id, module_id=module_id - ) - ) - super(X509AuthenticationProvider, self).__init__( - hostname=hostname, device_id=device_id, module_id=module_id - ) - self._x509 = x509 - - def get_x509_certificate(self): - """ - :return: The x509 certificate, To use the certificate the enrollment object needs to contain - cert (either the root certificate or one of the intermediate CA certificates). 
- """ - return self._x509 diff --git a/azure-iot-device/azure/iot/device/iothub/auth/iotedge_authentication_provider.py b/azure-iot-device/azure/iot/device/iothub/edge_hsm.py similarity index 61% rename from azure-iot-device/azure/iot/device/iothub/auth/iotedge_authentication_provider.py rename to azure-iot-device/azure/iot/device/iothub/edge_hsm.py index 9bdd19e29..f6c802466 100644 --- a/azure-iot-device/azure/iot/device/iothub/auth/iotedge_authentication_provider.py +++ b/azure-iot-device/azure/iot/device/iothub/edge_hsm.py @@ -1,22 +1,20 @@ -# ------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -import os -import base64 +import logging import json -import six.moves.urllib as urllib +import base64 import requests import requests_unixsocket -import logging -from .base_renewable_token_authentication_provider import BaseRenewableTokenAuthenticationProvider +from six.moves import urllib, http_client from azure.iot.device.common.chainable_exception import ChainableException -from azure.iot.device.product_info import ProductInfo +from azure.iot.device.common.auth.signing_mechanism import SigningMechanism +from azure.iot.device import user_agent requests_unixsocket.monkeypatch() - logger = logging.getLogger(__name__) @@ -24,57 +22,7 @@ class IoTEdgeError(ChainableException): pass -class IoTEdgeAuthenticationProvider(BaseRenewableTokenAuthenticationProvider): - """An Azure IoT Edge Authentication Provider. - - This provider creates the Shared Access Signature that would be needed to connenct to the IoT Edge runtime - """ - - def __init__( - self, - hostname, - device_id, - module_id, - gateway_hostname, - module_generation_id, - workload_uri, - api_version, - ): - """ - Constructor for IoT Edge Authentication Provider - """ - - logger.info("Using IoTEdge authentication for {%s, %s, %s}", hostname, device_id, module_id) - - super(IoTEdgeAuthenticationProvider, self).__init__( - hostname=hostname, device_id=device_id, module_id=module_id - ) - - self.hsm = IoTEdgeHsm( - module_id=module_id, - api_version=api_version, - module_generation_id=module_generation_id, - workload_uri=workload_uri, - ) - self.gateway_hostname = gateway_hostname - self.server_verification_cert = self.hsm.get_trust_bundle() - - # TODO: reconsider this design when refactoring the BaseRenewableToken auth parent - # TODO: Consider handling the quoting within this function, and renaming quoted_resource_uri to resource_uri - def _sign(self, quoted_resource_uri, expiry): - """ - Creates the signature to be inserted in the SAS token - :param resource_uri: the resource URI to encode into the token - :param expiry: an integer value representing the number of seconds since the epoch 00:00:00 UTC on 1 January 1970 at which the token will expire. - :return: The signature portion of the Sas Token. - - :raises: IoTEdgeError if data cannot be signed - """ - string_to_sign = quoted_resource_uri + "\n" + str(expiry) - return self.hsm.sign(string_to_sign) - - -class IoTEdgeHsm(object): +class IoTEdgeHsm(SigningMechanism): """ Constructor for instantiating a iot hsm object. 
This is an object that communicates with the Azure IoT Edge HSM in order to get connection credentials @@ -87,26 +35,24 @@ class IoTEdgeHsm(object): SharedAccessSignature string which can be used to authenticate with Iot Edge """ - def __init__(self, module_id, module_generation_id, workload_uri, api_version): + def __init__(self, module_id, generation_id, workload_uri, api_version): """ Constructor for instantiating a Azure IoT Edge HSM object :param str module_id: The module id :param str api_version: The API version - :param str module_generation_id: The module generation id + :param str generation_id: The module generation id :param str workload_uri: The workload uri """ - self.module_id = urllib.parse.quote(module_id) + self.module_id = urllib.parse.quote(module_id, safe="") self.api_version = api_version - self.module_generation_id = module_generation_id + self.generation_id = generation_id self.workload_uri = _format_socket_uri(workload_uri) - # TODO: Is this really the right name? It returns a certificate FROM the trust bundle, - # not the trust bundle itself - def get_trust_bundle(self): + def get_certificate(self): """ - Return the trust bundle that can be used to validate the server-side SSL - TLS connection that we use to talk to edgeHub. + Return the server verification certificate from the trust bundle that can be used to + validate the server-side SSL TLS connection that we use to talk to Edge :return: The server verification certificate to use for connections to the Azure IoT Edge instance, as a PEM certificate in string form. @@ -116,13 +62,13 @@ def get_trust_bundle(self): r = requests.get( self.workload_uri + "trust-bundle", params={"api-version": self.api_version}, - headers={"User-Agent": urllib.parse.quote_plus(ProductInfo.get_iothub_user_agent())}, + headers={"User-Agent": urllib.parse.quote_plus(user_agent.get_iothub_user_agent())}, ) # Validate that the request was successful try: r.raise_for_status() except requests.exceptions.HTTPError as e: - raise IoTEdgeError(message="Unable to get trust bundle from EdgeHub", cause=e) + raise IoTEdgeError(message="Unable to get trust bundle from Edge", cause=e) # Decode the trust bundle try: bundle = r.json() @@ -149,20 +95,15 @@ def sign(self, data_str): """ encoded_data_str = base64.b64encode(data_str.encode("utf-8")).decode() - path = ( - self.workload_uri - + "modules/" - + self.module_id - + "/genid/" - + self.module_generation_id - + "/sign" + path = "{workload_uri}modules/{module_id}/genid/{gen_id}/sign".format( + workload_uri=self.workload_uri, module_id=self.module_id, gen_id=self.generation_id ) sign_request = {"keyId": "primary", "algo": "HMACSHA256", "data": encoded_data_str} - r = requests.post( # TODO: can we use json field instead of data? + r = requests.post( # can we use json field instead of data? url=path, params={"api-version": self.api_version}, - headers={"User-Agent": urllib.parse.quote_plus(ProductInfo.get_iothub_user_agent())}, + headers={"User-Agent": urllib.parse.quote(user_agent.get_iothub_user_agent(), safe="")}, data=json.dumps(sign_request), ) try: @@ -178,7 +119,7 @@ def sign(self, data_str): except KeyError as e: raise IoTEdgeError(message="No signed data received", cause=e) - return urllib.parse.quote(signed_data_str) + return signed_data_str # what format is this? string? bytes? 
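Condensing the workload-API round trip shown in this hunk into a single hypothetical helper; the socket URI, module identity, generation id, API version, and the name of the response field are assumptions made for illustration, and the User-Agent header plus the IoTEdgeError translation from the real code are omitted.

    import base64
    import json
    import requests
    import requests_unixsocket

    requests_unixsocket.monkeypatch()

    # Placeholder values -- in the SDK these come from the IoT Edge environment.
    workload_uri = "http+unix://%2Fvar%2Frun%2Fiotedge%2Fworkload.sock/"
    module_id = "mymodule"
    generation_id = "1"
    api_version = "2019-01-30"  # assumed API version

    def edge_sign(data_str):
        """POST a base64-encoded payload to the Edge workload API and return its signature."""
        encoded_data_str = base64.b64encode(data_str.encode("utf-8")).decode()
        path = "{workload_uri}modules/{module_id}/genid/{gen_id}/sign".format(
            workload_uri=workload_uri, module_id=module_id, gen_id=generation_id
        )
        sign_request = {"keyId": "primary", "algo": "HMACSHA256", "data": encoded_data_str}
        r = requests.post(
            url=path, params={"api-version": api_version}, data=json.dumps(sign_request)
        )
        r.raise_for_status()
        return r.json()["digest"]  # the name of the response field is assumed here
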
def _format_socket_uri(old_uri): diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/config.py b/azure-iot-device/azure/iot/device/iothub/pipeline/config.py index 907a9f622..b3d7a4c53 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/config.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/config.py @@ -14,14 +14,23 @@ class IoTHubPipelineConfig(BasePipelineConfig): """A class for storing all configurations/options for IoTHub clients in the Azure IoT Python Device Client Library. """ - def __init__(self, product_info="", **kwargs): + def __init__(self, hostname, device_id, module_id=None, product_info="", **kwargs): """Initializer for IoTHubPipelineConfig which passes all unrecognized keyword-args down to BasePipelineConfig to be evaluated. This stacked options setting is to allow for unique configuration options to exist between the - IoTHub Client and the Provisioning Client, while maintaining a base configuration class with shared config options. + multiple clients, while maintaining a base configuration class with shared config options. + :param str hostname: The hostname of the IoTHub to connect to + :param str device_id: The device identity being used with the IoTHub + :param str module_id: The module identity being used with the IoTHub :param str product_info: A custom identification string for the type of device connecting to Azure IoT Hub. """ - super(IoTHubPipelineConfig, self).__init__(**kwargs) + super(IoTHubPipelineConfig, self).__init__(hostname=hostname, **kwargs) + + # IoTHub Connection Details + self.device_id = device_id + self.module_id = module_id + + # Product Info self.product_info = product_info # Now, the parameters below are not exposed to the user via kwargs. They need to be set by manipulating the IoTHubPipelineConfig object. diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/http_pipeline.py b/azure-iot-device/azure/iot/device/iothub/pipeline/http_pipeline.py index ad8281c4b..35503bb23 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/http_pipeline.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/http_pipeline.py @@ -22,7 +22,6 @@ pipeline_ops_iothub_http, pipeline_stages_iothub_http, ) -from azure.iot.device.iothub.auth.x509_authentication_provider import X509AuthenticationProvider logger = logging.getLogger(__name__) @@ -32,7 +31,7 @@ class HTTPPipeline(object): Uses HTTP. """ - def __init__(self, auth_provider, pipeline_configuration): + def __init__(self, pipeline_configuration): """ Constructor for instantiating a pipeline adapter object. @@ -40,22 +39,15 @@ def __init__(self, auth_provider, pipeline_configuration): :param pipeline_configuration: The configuration generated based on user inputs """ self._pipeline = ( - pipeline_stages_base.PipelineRootStage(pipeline_configuration=pipeline_configuration) - .append_stage(pipeline_stages_iothub.UseAuthProviderStage()) + pipeline_stages_base.PipelineRootStage(pipeline_configuration) + .append_stage(pipeline_stages_base.SasTokenRenewalStage()) .append_stage(pipeline_stages_iothub_http.IoTHubHTTPTranslationStage()) .append_stage(pipeline_stages_http.HTTPTransportStage()) ) callback = EventedCallback() - if isinstance(auth_provider, X509AuthenticationProvider): - op = pipeline_ops_iothub.SetX509AuthProviderOperation( - auth_provider=auth_provider, callback=callback - ) - else: # Currently everything else goes via this block. 
- op = pipeline_ops_iothub.SetAuthProviderOperation( - auth_provider=auth_provider, callback=callback - ) + op = pipeline_ops_base.InitializePipelineOperation(callback=callback) self._pipeline.run_op(op) callback.wait_for_completion() diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/mqtt_pipeline.py b/azure-iot-device/azure/iot/device/iothub/pipeline/mqtt_pipeline.py index b823b19c8..9ffaf29b3 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/mqtt_pipeline.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/mqtt_pipeline.py @@ -19,13 +19,12 @@ pipeline_ops_iothub, pipeline_stages_iothub_mqtt, ) -from azure.iot.device.iothub.auth.x509_authentication_provider import X509AuthenticationProvider logger = logging.getLogger(__name__) class MQTTPipeline(object): - def __init__(self, auth_provider, pipeline_configuration): + def __init__(self, pipeline_configuration): """ Constructor for instantiating a pipeline adapter object :param auth_provider: The authentication provider @@ -54,12 +53,12 @@ def __init__(self, auth_provider, pipeline_configuration): # # The root is always the root. By definition, it's the first stage in the pipeline. # - pipeline_stages_base.PipelineRootStage(pipeline_configuration=pipeline_configuration) + pipeline_stages_base.PipelineRootStage(pipeline_configuration) # - # UseAuthProviderStage comes near the root by default because it doesn't need to be after - # anything, but it does need to be before IoTHubMQTTTranslationStage. + # SasTokenRenewalStage comes near the root by default because it should be as close + # to the top of the pipeline as possible, and does not need to be after anything. # - .append_stage(pipeline_stages_iothub.UseAuthProviderStage()) + .append_stage(pipeline_stages_base.SasTokenRenewalStage()) # # EnsureDesiredPropertiesStage needs to be above TwinRequestResponseStage because it # sends GetTwinOperation ops and that stage handles those ops. @@ -162,14 +161,10 @@ def _on_disconnected(): callback = EventedCallback() - if isinstance(auth_provider, X509AuthenticationProvider): - op = pipeline_ops_iothub.SetX509AuthProviderOperation( - auth_provider=auth_provider, callback=callback - ) - else: # Currently everything else goes via this block. - op = pipeline_ops_iothub.SetAuthProviderOperation( - auth_provider=auth_provider, callback=callback - ) + # NOTE: It would be nice if this didn't have to go down as a dynamic operation. + # At this time we haven't been able to figure out a better way to make it work though. + + op = pipeline_ops_base.InitializePipelineOperation(callback=callback) self._pipeline.run_op(op) callback.wait_for_completion() @@ -236,7 +231,7 @@ def on_complete(op, error): pipeline_ops_iothub.SendD2CMessageOperation(message=message, callback=on_complete) ) - def send_output_event(self, message, callback): + def send_output_message(self, message, callback): """ Send an output message to the service. 
@@ -256,7 +251,7 @@ def on_complete(op, error): callback(error=error) self._pipeline.run_op( - pipeline_ops_iothub.SendOutputEventOperation(message=message, callback=on_complete) + pipeline_ops_iothub.SendOutputMessageOperation(message=message, callback=on_complete) ) def send_method_response(self, method_response, callback): diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub.py index 117a8c731..75a44f7b3 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_ops_iothub.py @@ -6,104 +6,6 @@ from azure.iot.device.common.pipeline import PipelineOperation -# TODO: Combine SetAuthProviderOperation and SetX509AuthProviderOperation once -# auth provider is reduced to a simple vector -class SetX509AuthProviderOperation(PipelineOperation): - """ - A PipelineOperation object which tells the pipeline to use a particular X509 authorization provider. - Some pipeline stage is expected to extract arguments out of the auth provider and pass them - on so an even lower stage can use those arguments to connect. - - This operation is in the group of IoTHub operations because authorization providers are currently - very IoTHub-specific - """ - - def __init__(self, auth_provider, callback): - """ - Initializer for SetAuthProviderOperation objects. - - :param object auth_provider: The X509 authorization provider object to use to retrieve connection parameters - which can be used to connect to the service. - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. - """ - super(SetX509AuthProviderOperation, self).__init__(callback=callback) - self.auth_provider = auth_provider - - -class SetAuthProviderOperation(PipelineOperation): - """ - A PipelineOperation object which tells the pipeline to use a particular authorization provider. - Some pipeline stage is expected to extract arguments out of the auth provider and pass them - on so an even lower stage can use those arguments to connect. - - This operation is in the group of IoTHub operations because autorization providers are currently - very IoTHub-specific - """ - - def __init__(self, auth_provider, callback): - """ - Initializer for SetAuthProviderOperation objects. - - :param object auth_provider: The authorization provider object to use to retrieve connection parameters - which can be used to connect to the service. - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. - """ - super(SetAuthProviderOperation, self).__init__(callback=callback) - self.auth_provider = auth_provider - - -class SetIoTHubConnectionArgsOperation(PipelineOperation): - """ - A PipelineOperation object which contains connection arguments which were retrieved from an authorization provider, - likely by a pipeline stage which handles the SetAuthProviderOperation operation. - - This operation is in the group of IoTHub operations because the arguments which it accepts are very specific to - IoTHub connections and would not apply to other types of client connections (such as a DPS client). 
- """ - - def __init__( - self, - device_id, - hostname, - callback, - module_id=None, - gateway_hostname=None, - server_verification_cert=None, - client_cert=None, - sas_token=None, - ): - """ - Initializer for SetIoTHubConnectionArgsOperation objects. - - :param str device_id: The device id for the device that we are connecting. - :param str hostname: The hostname of the iothub service we are connecting to. - :param str module_id: (optional) If we are connecting as a module, this contains the module id - for the module we are connecting. - :param str gateway_hostname: (optional) If we are going through a gateway host, this is the - hostname for the gateway - :param str server_verification_cert: (Optional) The server verification certificate to use - if the server that we're going to connect to uses server-side TLS - :param X509 client_cert: (Optional) The x509 object containing a client certificate and key used to connect - to the service - :param str sas_token: The token string which will be used to authenticate with the service - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. - """ - super(SetIoTHubConnectionArgsOperation, self).__init__(callback=callback) - self.device_id = device_id - self.module_id = module_id - self.hostname = hostname - self.gateway_hostname = gateway_hostname - self.server_verification_cert = server_verification_cert - self.client_cert = client_cert - self.sas_token = sas_token - - class SendD2CMessageOperation(PipelineOperation): """ A PipelineOperation object which contains arguments used to send a telemetry message to an IoTHub or EdegHub server. @@ -124,7 +26,7 @@ def __init__(self, message, callback): self.message = message -class SendOutputEventOperation(PipelineOperation): +class SendOutputMessageOperation(PipelineOperation): """ A PipelineOperation object which contains arguments used to send an output message to an EdgeHub server. @@ -133,7 +35,7 @@ class SendOutputEventOperation(PipelineOperation): def __init__(self, message, callback): """ - Initializer for SendOutputEventOperation objects. + Initializer for SendOutputMessageOperation objects. :param Message message: The output message that we're sending to the service. The name of the output is expected to be stored in the output_name attribute of this object @@ -141,7 +43,7 @@ def __init__(self, message, callback): The callback function must accept A PipelineOperation object which indicates the specific operation which has completed or failed. """ - super(SendOutputEventOperation, self).__init__(callback=callback) + super(SendOutputMessageOperation, self).__init__(callback=callback) self.message = message diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub.py index caf71302b..f4e02b6c1 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub.py @@ -21,85 +21,6 @@ logger = logging.getLogger(__name__) -class UseAuthProviderStage(PipelineStage): - """ - PipelineStage which extracts relevant AuthenticationProvider values for a new - SetIoTHubConnectionArgsOperation. - - All other operations are passed down. 
- """ - - def __init__(self): - super(UseAuthProviderStage, self).__init__() - self.auth_provider = None - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - if isinstance(op, pipeline_ops_iothub.SetAuthProviderOperation): - self.auth_provider = op.auth_provider - # Here we append rather than just add it to the handler value because otherwise it - # would overwrite the handler from another pipeline that might be using the same auth provider. - self.auth_provider.on_sas_token_updated_handler_list.append( - CallableWeakMethod(self, "_on_sas_token_updated") - ) - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_iothub.SetIoTHubConnectionArgsOperation, - device_id=self.auth_provider.device_id, - module_id=self.auth_provider.module_id, - hostname=self.auth_provider.hostname, - gateway_hostname=getattr(self.auth_provider, "gateway_hostname", None), - server_verification_cert=getattr( - self.auth_provider, "server_verification_cert", None - ), - sas_token=self.auth_provider.get_current_sas_token(), - ) - self.send_op_down(worker_op) - - elif isinstance(op, pipeline_ops_iothub.SetX509AuthProviderOperation): - self.auth_provider = op.auth_provider - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_iothub.SetIoTHubConnectionArgsOperation, - device_id=self.auth_provider.device_id, - module_id=self.auth_provider.module_id, - hostname=self.auth_provider.hostname, - gateway_hostname=getattr(self.auth_provider, "gateway_hostname", None), - server_verification_cert=getattr( - self.auth_provider, "server_verification_cert", None - ), - client_cert=self.auth_provider.get_x509_certificate(), - ) - self.send_op_down(worker_op) - else: - super(UseAuthProviderStage, self)._run_op(op) - - @pipeline_thread.invoke_on_pipeline_thread_nowait - def _on_sas_token_updated(self): - logger.info( - "{}: New sas token received. Passing down UpdateSasTokenOperation.".format(self.name) - ) - - @pipeline_thread.runs_on_pipeline_thread - def on_token_update_complete(op, error): - if error: - logger.error( - "{}({}): token update operation failed. Error={}".format( - self.name, op.name, error - ) - ) - handle_exceptions.handle_background_exception(error) - else: - logger.debug( - "{}({}): token update operation is complete".format(self.name, op.name) - ) - - self.send_op_down( - pipeline_ops_base.UpdateSasTokenOperation( - sas_token=self.auth_provider.get_current_sas_token(), - callback=on_token_update_complete, - ) - ) - - class EnsureDesiredPropertiesStage(PipelineStage): """ Pipeline stage Responsible for making sure that desired properties are always kept up to date. diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_http.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_http.py index 7b156488b..c8bbdf9a2 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_http.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_http.py @@ -17,7 +17,7 @@ from . import pipeline_ops_iothub, pipeline_ops_iothub_http, http_path_iothub, http_map_error from azure.iot.device import exceptions from azure.iot.device import constant as pkg_constant -from azure.iot.device.product_info import ProductInfo +from azure.iot.device import user_agent logger = logging.getLogger(__name__) @@ -40,40 +40,9 @@ class IoTHubHTTPTranslationStage(PipelineStage): converts http pipeline events into Iot and EdgeHub pipeline events. 
""" - def __init__(self): - super(IoTHubHTTPTranslationStage, self).__init__() - self.device_id = None - self.module_id = None - self.hostname = None - @pipeline_thread.runs_on_pipeline_thread def _run_op(self, op): - if isinstance(op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation): - self.device_id = op.device_id - self.module_id = op.module_id - - if op.gateway_hostname: - logger.debug( - "Gateway Hostname Present. Setting Hostname to: {}".format(op.gateway_hostname) - ) - self.hostname = op.gateway_hostname - else: - logger.debug( - "Gateway Hostname not present. Setting Hostname to: {}".format( - op.gateway_hostname - ) - ) - self.hostname = op.hostname - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_http.SetHTTPConnectionArgsOperation, - hostname=self.hostname, - server_verification_cert=op.server_verification_cert, - client_cert=op.client_cert, - sas_token=op.sas_token, - ) - self.send_op_down(worker_op) - - elif isinstance(op, pipeline_ops_iothub_http.MethodInvokeOperation): + if isinstance(op, pipeline_ops_iothub_http.MethodInvokeOperation): logger.debug( "{}({}): Translating Method Invoke Operation for HTTP.".format(self.name, op.name) ) @@ -84,22 +53,23 @@ def _run_op(self, op): body = json.dumps(op.method_params) path = http_path_iothub.get_method_invoke_path(op.target_device_id, op.target_module_id) - # Note we do not add the sas Authorization header here. Instead we add it later on in the stage above - # the transport layer, since that stage stores the updated SAS and also X509 certs if that is what is - # being used. + # NOTE: we do not add the sas Authorization header here. Instead we add it later on in + # the HTTPTransportStage x_ms_edge_string = "{deviceId}/{moduleId}".format( - deviceId=self.device_id, moduleId=self.module_id + deviceId=self.pipeline_root.pipeline_configuration.device_id, + moduleId=self.pipeline_root.pipeline_configuration.module_id, ) # these are the identifiers of the current module - user_agent = urllib.parse.quote_plus( - ProductInfo.get_iothub_user_agent() + user_agent_string = urllib.parse.quote_plus( + user_agent.get_iothub_user_agent() + str(self.pipeline_root.pipeline_configuration.product_info) ) + # Method Invoke must be addressed to the gateway hostname because it is an Edge op headers = { - "Host": self.hostname, + "Host": self.pipeline_root.pipeline_configuration.gateway_hostname, "Content-Type": "application/json", "Content-Length": len(str(body)), "x-ms-edge-moduleId": x_ms_edge_string, - "User-Agent": user_agent, + "User-Agent": user_agent_string, } op_waiting_for_response = op @@ -132,18 +102,20 @@ def on_request_response(op, error): query_params = "api-version={apiVersion}".format( apiVersion=pkg_constant.IOTHUB_API_VERSION ) - path = http_path_iothub.get_storage_info_for_blob_path(self.device_id) + path = http_path_iothub.get_storage_info_for_blob_path( + self.pipeline_root.pipeline_configuration.device_id + ) body = json.dumps({"blobName": op.blob_name}) - user_agent = urllib.parse.quote_plus( - ProductInfo.get_iothub_user_agent() + user_agent_string = urllib.parse.quote_plus( + user_agent.get_iothub_user_agent() + str(self.pipeline_root.pipeline_configuration.product_info) ) headers = { - "Host": self.hostname, + "Host": self.pipeline_root.pipeline_configuration.hostname, "Accept": "application/json", "Content-Type": "application/json", "Content-Length": len(str(body)), - "User-Agent": user_agent, + "User-Agent": user_agent_string, } op_waiting_for_response = op @@ -177,7 +149,9 @@ def on_request_response(op, 
error): query_params = "api-version={apiVersion}".format( apiVersion=pkg_constant.IOTHUB_API_VERSION ) - path = http_path_iothub.get_notify_blob_upload_status_path(self.device_id) + path = http_path_iothub.get_notify_blob_upload_status_path( + self.pipeline_root.pipeline_configuration.device_id + ) body = json.dumps( { "correlationId": op.correlation_id, @@ -186,19 +160,18 @@ def on_request_response(op, error): "statusDescription": op.status_description, } ) - user_agent = urllib.parse.quote_plus( - ProductInfo.get_iothub_user_agent() + user_agent_string = urllib.parse.quote_plus( + user_agent.get_iothub_user_agent() + str(self.pipeline_root.pipeline_configuration.product_info) ) - # Note we do not add the sas Authorization header here. Instead we add it later on in the stage above - # the transport layer, since that stage stores the updated SAS and also X509 certs if that is what is - # being used. + # NOTE we do not add the sas Authorization header here. Instead we add it later on in + # the HTTPTransportStage headers = { - "Host": self.hostname, + "Host": self.pipeline_root.pipeline_configuration.hostname, "Content-Type": "application/json; charset=utf-8", "Content-Length": len(str(body)), - "User-Agent": user_agent, + "User-Agent": user_agent_string, } op_waiting_for_response = op diff --git a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_mqtt.py b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_mqtt.py index 4bd46b6e6..c3661d131 100644 --- a/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_mqtt.py +++ b/azure-iot-device/azure/iot/device/iothub/pipeline/pipeline_stages_iothub_mqtt.py @@ -21,7 +21,7 @@ from . import constant as pipeline_constant from . import exceptions as pipeline_exceptions from azure.iot.device import constant as pkg_constant -from azure.iot.device.product_info import ProductInfo +from azure.iot.device import user_agent logger = logging.getLogger(__name__) @@ -32,108 +32,55 @@ class IoTHubMQTTTranslationStage(PipelineStage): converts mqtt pipeline events into Iot and IoTHub pipeline events. """ - def __init__(self): - super(IoTHubMQTTTranslationStage, self).__init__() - self.feature_to_topic = {} - self.device_id = None - self.module_id = None - @pipeline_thread.runs_on_pipeline_thread def _run_op(self, op): - if isinstance(op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation): - self.device_id = op.device_id - self.module_id = op.module_id - - # if we get auth provider args from above, we save some, use some to build topic names, - # and always pass it down because we know that the MQTT protocol stage will also want - # to receive these args. 
- self._set_topic_names(device_id=op.device_id, module_id=op.module_id) + if isinstance(op, pipeline_ops_base.InitializePipelineOperation): - if op.module_id: - client_id = "{}/{}".format(op.device_id, op.module_id) + if self.pipeline_root.pipeline_configuration.module_id: + # Module Format + client_id = "{}/{}".format( + self.pipeline_root.pipeline_configuration.device_id, + self.pipeline_root.pipeline_configuration.module_id, + ) else: - client_id = op.device_id + # Device Format + client_id = self.pipeline_root.pipeline_configuration.device_id - # For MQTT, the entire user agent string should be appended to the username field in the connect packet - # For example, the username may look like this without custom parameters: - # yosephsandboxhub.azure-devices.net/alpha/?api-version=2018-06-30&DeviceClientType=py-azure-iot-device%2F2.0.0-preview.12 - # The customer user agent string would simply be appended to the end of this username, in URL Encoded format. + # Apply query parameters (i.e. key1=value1&key2=value2...&keyN=valueN format) query_param_seq = [ ("api-version", pkg_constant.IOTHUB_API_VERSION), ( "DeviceClientType", - ProductInfo.get_iothub_user_agent() - + str(self.pipeline_root.pipeline_configuration.product_info), + user_agent.get_iothub_user_agent() + + self.pipeline_root.pipeline_configuration.product_info, ), ] username = "{hostname}/{client_id}/?{query_params}".format( - hostname=op.hostname, + hostname=self.pipeline_root.pipeline_configuration.hostname, client_id=client_id, query_params=version_compat.urlencode( query_param_seq, quote_via=urllib.parse.quote ), ) - if op.gateway_hostname: - hostname = op.gateway_hostname - else: - hostname = op.hostname - - # TODO: test to make sure client_cert and sas_token travel down correctly - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_mqtt.SetMQTTConnectionArgsOperation, - client_id=client_id, - hostname=hostname, - username=username, - server_verification_cert=op.server_verification_cert, - client_cert=op.client_cert, - sas_token=op.sas_token, - ) - self.send_op_down(worker_op) - - elif ( - isinstance(op, pipeline_ops_base.UpdateSasTokenOperation) - and self.pipeline_root.connected - ): - logger.debug( - "{}({}): Connected. Passing op down and reauthorizing after token is updated.".format( - self.name, op.name - ) - ) - - # make a callback that either fails the UpdateSasTokenOperation (if the lower level failed it), - # or issues a ReauthorizeConnectionOperation (if the lower level returned success for the UpdateSasTokenOperation) - def on_token_update_complete(op, error): - if error: - logger.error( - "{}({}) token update failed. returning failure {}".format( - self.name, op.name, error - ) - ) - else: - logger.debug( - "{}({}) token update succeeded. reauthorizing".format(self.name, op.name) - ) + # Dynamically attach the derived MQTT values to the InitalizePipelineOperation + # to be used later down the pipeline + op.username = username + op.client_id = client_id - # Stop completion of Token Update op, and only continue upon completion of ReauthorizeConnectionOperation - op.halt_completion() - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_base.ReauthorizeConnectionOperation - ) - - self.send_op_down(worker_op) - - # now, pass the UpdateSasTokenOperation down with our new callback. 
- op.add_callback(on_token_update_complete) self.send_op_down(op) elif isinstance(op, pipeline_ops_iothub.SendD2CMessageOperation) or isinstance( - op, pipeline_ops_iothub.SendOutputEventOperation + op, pipeline_ops_iothub.SendOutputMessageOperation ): - # Convert SendTelementry and SendOutputEventOperation operations into MQTT Publish operations + # Convert SendTelementry and SendOutputMessageOperation operations into MQTT Publish operations + telemetry_topic = mqtt_topic_iothub.get_telemetry_topic_for_publish( + device_id=self.pipeline_root.pipeline_configuration.device_id, + module_id=self.pipeline_root.pipeline_configuration.module_id, + ) topic = mqtt_topic_iothub.encode_message_properties_in_topic( - op.message, self.telemetry_topic + op.message, telemetry_topic ) worker_op = op.spawn_worker_op( worker_op_type=pipeline_ops_mqtt.MQTTPublishOperation, @@ -145,7 +92,7 @@ def on_token_update_complete(op, error): elif isinstance(op, pipeline_ops_iothub.SendMethodResponseOperation): # Sending a Method Response gets translated into an MQTT Publish operation topic = mqtt_topic_iothub.get_method_topic_for_publish( - op.method_response.request_id, str(op.method_response.status) + op.method_response.request_id, op.method_response.status ) payload = json.dumps(op.method_response.payload) worker_op = op.spawn_worker_op( @@ -155,7 +102,7 @@ def on_token_update_complete(op, error): elif isinstance(op, pipeline_ops_base.EnableFeatureOperation): # Enabling a feature gets translated into an MQTT subscribe operation - topic = self.feature_to_topic[op.feature_name] + topic = self._get_feature_subscription_topic(op.feature_name) worker_op = op.spawn_worker_op( worker_op_type=pipeline_ops_mqtt.MQTTSubscribeOperation, topic=topic ) @@ -163,7 +110,7 @@ def on_token_update_complete(op, error): elif isinstance(op, pipeline_ops_base.DisableFeatureOperation): # Disabling a feature gets turned into an MQTT unsubscribe operation - topic = self.feature_to_topic[op.feature_name] + topic = self._get_feature_subscription_topic(op.feature_name) worker_op = op.spawn_worker_op( worker_op_type=pipeline_ops_mqtt.MQTTUnsubscribeOperation, topic=topic ) @@ -192,24 +139,27 @@ def on_token_update_complete(op, error): super(IoTHubMQTTTranslationStage, self)._run_op(op) @pipeline_thread.runs_on_pipeline_thread - def _set_topic_names(self, device_id, module_id): - """ - Build topic names based on the device_id and module_id passed. 
- """ - self.telemetry_topic = mqtt_topic_iothub.get_telemetry_topic_for_publish( - device_id, module_id - ) - self.feature_to_topic = { - pipeline_constant.C2D_MSG: (mqtt_topic_iothub.get_c2d_topic_for_subscribe(device_id)), - pipeline_constant.INPUT_MSG: ( - mqtt_topic_iothub.get_input_topic_for_subscribe(device_id, module_id) - ), - pipeline_constant.METHODS: (mqtt_topic_iothub.get_method_topic_for_subscribe()), - pipeline_constant.TWIN: (mqtt_topic_iothub.get_twin_response_topic_for_subscribe()), - pipeline_constant.TWIN_PATCHES: ( - mqtt_topic_iothub.get_twin_patch_topic_for_subscribe() - ), - } + def _get_feature_subscription_topic(self, feature): + if feature == pipeline_constant.C2D_MSG: + return mqtt_topic_iothub.get_c2d_topic_for_subscribe( + self.pipeline_root.pipeline_configuration.device_id + ) + elif feature == pipeline_constant.INPUT_MSG: + return mqtt_topic_iothub.get_input_topic_for_subscribe( + self.pipeline_root.pipeline_configuration.device_id, + self.pipeline_root.pipeline_configuration.module_id, + ) + elif feature == pipeline_constant.METHODS: + return mqtt_topic_iothub.get_method_topic_for_subscribe() + elif feature == pipeline_constant.TWIN: + return mqtt_topic_iothub.get_twin_response_topic_for_subscribe() + elif feature == pipeline_constant.TWIN_PATCHES: + return mqtt_topic_iothub.get_twin_patch_topic_for_subscribe() + else: + logger.error("Cannot retrieve MQTT topic for subscription to invalid feature") + raise pipeline_exceptions.OperationError( + "Trying to enable/disable invalid feature - {}".format(feature) + ) @pipeline_thread.runs_on_pipeline_thread def _handle_pipeline_event(self, event): @@ -217,15 +167,19 @@ def _handle_pipeline_event(self, event): Pipeline Event handler function to convert incoming MQTT messages into the appropriate IoTHub events, based on the topic of the message """ + # TODO: should we always be decoding the payload? Seems strange to only sometimes do it. + # Is there value to the user getting the original bytestring from the wire? 
if isinstance(event, pipeline_events_mqtt.IncomingMQTTMessageEvent): topic = event.topic + device_id = self.pipeline_root.pipeline_configuration.device_id + module_id = self.pipeline_root.pipeline_configuration.module_id - if mqtt_topic_iothub.is_c2d_topic(topic, self.device_id): + if mqtt_topic_iothub.is_c2d_topic(topic, device_id): message = Message(event.payload) mqtt_topic_iothub.extract_message_properties_from_topic(topic, message) self.send_event_up(pipeline_events_iothub.C2DMessageEvent(message)) - elif mqtt_topic_iothub.is_input_topic(topic, self.device_id, self.module_id): + elif mqtt_topic_iothub.is_input_topic(topic, device_id, module_id): message = Message(event.payload) mqtt_topic_iothub.extract_message_properties_from_topic(topic, message) input_name = mqtt_topic_iothub.get_input_name_from_topic(topic) @@ -263,4 +217,4 @@ def _handle_pipeline_event(self, event): else: # all other messages get passed up - super(IoTHubMQTTTranslationStage, self)._handle_pipeline_event(event) + self.send_event_up(event) diff --git a/azure-iot-device/azure/iot/device/iothub/sync_clients.py b/azure-iot-device/azure/iot/device/iothub/sync_clients.py index e4f45bda9..1d8b73c9a 100644 --- a/azure-iot-device/azure/iot/device/iothub/sync_clients.py +++ b/azure-iot-device/azure/iot/device/iothub/sync_clients.py @@ -478,7 +478,7 @@ def send_message_to_output(self, message, output_name): logger.info("Sending message to output:" + output_name + "...") callback = EventedCallback() - self._mqtt_pipeline.send_output_event(message, callback=callback) + self._mqtt_pipeline.send_output_message(message, callback=callback) handle_result(callback) logger.info("Successfully sent message to output: " + output_name) diff --git a/azure-iot-device/azure/iot/device/product_info.py b/azure-iot-device/azure/iot/device/product_info.py deleted file mode 100644 index c8b735b12..000000000 --- a/azure-iot-device/azure/iot/device/product_info.py +++ /dev/null @@ -1,50 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import platform -from azure.iot.device.constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER - -python_runtime = platform.python_version() -os_type = platform.system() -os_release = platform.version() -architecture = platform.machine() - - -class ProductInfo(object): - """ - A class for creating product identifiers or agent strings for IotHub as well as Provisioning. 
- """ - - @staticmethod - def _get_common_user_agent(): - return "({python_runtime};{os_type} {os_release};{architecture})".format( - python_runtime=python_runtime, - os_type=os_type, - os_release=os_release, - architecture=architecture, - ) - - @staticmethod - def get_iothub_user_agent(): - """ - Create the user agent for IotHub - """ - return "{iothub_iden}/{version}{common}".format( - iothub_iden=IOTHUB_IDENTIFIER, - version=VERSION, - common=ProductInfo._get_common_user_agent(), - ) - - @staticmethod - def get_provisioning_user_agent(): - """ - Create the user agent for Provisioning - """ - return "{provisioning_iden}/{version}{common}".format( - provisioning_iden=PROVISIONING_IDENTIFIER, - version=VERSION, - common=ProductInfo._get_common_user_agent(), - ) diff --git a/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py b/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py index 456d197f4..e269745db 100644 --- a/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py +++ b/azure-iot-device/azure/iot/device/provisioning/abstract_provisioning_device_client.py @@ -11,7 +11,10 @@ import abc import six import logging -from azure.iot.device.provisioning import pipeline, security +from azure.iot.device.provisioning import pipeline + +from azure.iot.device.common.auth import sastoken as st +from azure.iot.device.common import auth logger = logging.getLogger(__name__) @@ -20,20 +23,26 @@ def _validate_kwargs(**kwargs): """Helper function to validate user provided kwargs. Raises TypeError if an invalid option has been provided""" # TODO: add support for server_verification_cert - valid_kwargs = ["websockets", "cipher"] + valid_kwargs = ["websockets", "cipher", "proxy_options"] for kwarg in kwargs: if kwarg not in valid_kwargs: raise TypeError("Got an unexpected keyword argument '{}'".format(kwarg)) +def _form_sas_uri(id_scope, registration_id): + return "{id_scope}/registrations/{registration_id}".format( + id_scope=id_scope, registration_id=registration_id + ) + + @six.add_metaclass(abc.ABCMeta) class AbstractProvisioningDeviceClient(object): """ Super class for any client that can be used to register devices to Device Provisioning Service. """ - def __init__(self, provisioning_pipeline): + def __init__(self, pipeline): """ Initializes the provisioning client. @@ -41,10 +50,10 @@ def __init__(self, provisioning_pipeline): Instead, the class methods that start with `create_from_` should be used to create a client object. - :param provisioning_pipeline: Instance of the provisioning pipeline object. - :type provisioning_pipeline: :class:`azure.iot.device.provisioning.pipeline.ProvisioningPipeline` + :param pipeline: Instance of the provisioning pipeline object. + :type pipeline: :class:`azure.iot.device.provisioning.pipeline.MQTTPipeline` """ - self._provisioning_pipeline = provisioning_pipeline + self._pipeline = pipeline self._provisioning_payload = None @classmethod @@ -83,18 +92,31 @@ def create_from_symmetric_key( :returns: A ProvisioningDeviceClient instance which can register via Symmetric Key. 
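A brief caller-side usage sketch for the factory method documented above; the host, ID scope, registration ID, and key are fabricated example values, and the top-level import path is assumed from the package's public interface.

    from azure.iot.device import ProvisioningDeviceClient

    provisioning_client = ProvisioningDeviceClient.create_from_symmetric_key(
        provisioning_host="global.azure-devices-provisioning.net",  # typical public DPS endpoint
        registration_id="my-registration-id",
        id_scope="0ne00000000",                  # placeholder ID scope
        symmetric_key="bm90LWEtcmVhbC1rZXk=",    # placeholder base64 key
    )
    registration_result = provisioning_client.register()
    print(registration_result)
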
""" + # Ensure no invalid kwargs were passed by the user _validate_kwargs(**kwargs) - security_client = security.SymmetricKeySecurityClient( - provisioning_host=provisioning_host, + # Create SasToken + uri = _form_sas_uri(id_scope=id_scope, registration_id=registration_id) + signing_mechanism = auth.SymmetricKeySigningMechanism(key=symmetric_key) + try: + sastoken = st.SasToken(uri, signing_mechanism) + except st.SasTokenError as e: + new_err = ValueError("Could not create a SasToken using the provided values") + new_err.__cause__ = e + raise new_err + + # Pipeline Config setup + pipeline_configuration = pipeline.ProvisioningPipelineConfig( + hostname=provisioning_host, registration_id=registration_id, id_scope=id_scope, - symmetric_key=symmetric_key, - ) - pipeline_configuration = pipeline.ProvisioningPipelineConfig(**kwargs) - mqtt_provisioning_pipeline = pipeline.ProvisioningPipeline( - security_client, pipeline_configuration + sastoken=sastoken, + **kwargs ) + + # Pipeline setup + mqtt_provisioning_pipeline = pipeline.MQTTPipeline(pipeline_configuration) + return cls(mqtt_provisioning_pipeline) @classmethod @@ -131,18 +153,21 @@ def create_from_x509_certificate( :returns: A ProvisioningDeviceClient which can register via Symmetric Key. """ + # Ensure no invalid kwargs were passed by the user _validate_kwargs(**kwargs) - security_client = security.X509SecurityClient( - provisioning_host=provisioning_host, + # Pipeline Config setup + pipeline_configuration = pipeline.ProvisioningPipelineConfig( + hostname=provisioning_host, registration_id=registration_id, id_scope=id_scope, x509=x509, + **kwargs ) - pipeline_configuration = pipeline.ProvisioningPipelineConfig(**kwargs) - mqtt_provisioning_pipeline = pipeline.ProvisioningPipeline( - security_client, pipeline_configuration - ) + + # Pipeline setup + mqtt_provisioning_pipeline = pipeline.MQTTPipeline(pipeline_configuration) + return cls(mqtt_provisioning_pipeline) @abc.abstractmethod diff --git a/azure-iot-device/azure/iot/device/provisioning/aio/async_provisioning_device_client.py b/azure-iot-device/azure/iot/device/provisioning/aio/async_provisioning_device_client.py index d895b6bee..e3a2e210c 100644 --- a/azure-iot-device/azure/iot/device/provisioning/aio/async_provisioning_device_client.py +++ b/azure-iot-device/azure/iot/device/provisioning/aio/async_provisioning_device_client.py @@ -68,10 +68,10 @@ async def register(self): """ logger.info("Registering with Provisioning Service...") - if not self._provisioning_pipeline.responses_enabled[dps_constant.REGISTER]: + if not self._pipeline.responses_enabled[dps_constant.REGISTER]: await self._enable_responses() - register_async = async_adapter.emulate_async(self._provisioning_pipeline.register) + register_async = async_adapter.emulate_async(self._pipeline.register) register_complete = async_adapter.AwaitableCallback(return_arg_name="result") await register_async(payload=self._provisioning_payload, callback=register_complete) @@ -84,7 +84,7 @@ async def _enable_responses(self): """Enable to receive responses from Device Provisioning Service. 
""" logger.info("Enabling reception of response from Device Provisioning Service...") - subscribe_async = async_adapter.emulate_async(self._provisioning_pipeline.enable_responses) + subscribe_async = async_adapter.emulate_async(self._pipeline.enable_responses) subscription_complete = async_adapter.AwaitableCallback() await subscribe_async(callback=subscription_complete) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/__init__.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/__init__.py index 3a9ac5918..680f86011 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/__init__.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/__init__.py @@ -4,5 +4,5 @@ INTERNAL USAGE ONLY """ -from .provisioning_pipeline import ProvisioningPipeline +from .mqtt_pipeline import MQTTPipeline from .config import ProvisioningPipelineConfig diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/config.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/config.py index 2a6342c04..1ee6c7820 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/config.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/config.py @@ -14,4 +14,17 @@ class ProvisioningPipelineConfig(BasePipelineConfig): """A class for storing all configurations/options for Provisioning clients in the Azure IoT Python Device Client Library. """ - pass + def __init__(self, hostname, registration_id, id_scope, **kwargs): + """Initializer for ProvisioningPipelineConfig which passes all unrecognized keyword-args down to BasePipelineConfig + to be evaluated. This stacked options setting is to allow for unique configuration options to exist between the + multiple clients, while maintaining a base configuration class with shared config options. 
+ + :param str hostname: The hostname of the Provisioning hub instance to connect to + :param str registration_id: The device registration identity being provisioned + :param str id_scope: The identity of the provisoning service being used + """ + super(ProvisioningPipelineConfig, self).__init__(hostname=hostname, **kwargs) + + # Provisioning Connection Details + self.registration_id = registration_id + self.id_scope = id_scope diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/provisioning_pipeline.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_pipeline.py similarity index 89% rename from azure-iot-device/azure/iot/device/provisioning/pipeline/provisioning_pipeline.py rename to azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_pipeline.py index 76229554d..a42e5e274 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/provisioning_pipeline.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/mqtt_pipeline.py @@ -14,14 +14,13 @@ pipeline_stages_provisioning_mqtt, ) from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning -from azure.iot.device.provisioning.security import SymmetricKeySecurityClient, X509SecurityClient from azure.iot.device.provisioning.pipeline import constant as provisioning_constants logger = logging.getLogger(__name__) -class ProvisioningPipeline(object): - def __init__(self, security_client, pipeline_configuration): +class MQTTPipeline(object): + def __init__(self, pipeline_configuration): """ Constructor for instantiating a pipeline :param security_client: The security client which stores credentials @@ -32,7 +31,7 @@ def __init__(self, security_client, pipeline_configuration): self.on_connected = None self.on_disconnected = None self.on_message_received = None - self._registration_id = security_client.registration_id + self._registration_id = pipeline_configuration.registration_id self._pipeline = ( # @@ -40,10 +39,10 @@ def __init__(self, security_client, pipeline_configuration): # pipeline_stages_base.PipelineRootStage(pipeline_configuration=pipeline_configuration) # - # UseSecurityClientStager comes near the root by default because it doesn't need to be after - # anything, but it does need to be before ProvisoningMQTTTranslationStage. + # SasTokenRenewalStage comes near the root by default because it should be as close + # to the top of the pipeline as possible, and does not need to be after anything. 
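# NOTE: The following is an illustrative sketch (not part of this diff) of the
# auth flow that replaces the old SymmetricKeySecurityClient: the client forms
# the SAS URI, wraps the symmetric key in a SymmetricKeySigningMechanism,
# builds a SasToken from it, and hands the token to ProvisioningPipelineConfig,
# mirroring create_from_symmetric_key above. All literal values (hostname,
# id scope, registration id, key) are placeholders.
from azure.iot.device.common import auth
from azure.iot.device.common.auth import sastoken as st
from azure.iot.device.provisioning import pipeline

uri = "{id_scope}/registrations/{registration_id}".format(
    id_scope="0ne00000000", registration_id="my-device-001"
)
signing_mechanism = auth.SymmetricKeySigningMechanism(key="Zm9vYmFy")
sastoken = st.SasToken(uri, signing_mechanism)  # default TTL of 3600 seconds

pipeline_configuration = pipeline.ProvisioningPipelineConfig(
    hostname="global.azure-devices-provisioning.net",
    registration_id="my-device-001",
    id_scope="0ne00000000",
    sastoken=sastoken,
)
mqtt_provisioning_pipeline = pipeline.MQTTPipeline(pipeline_configuration)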
# - .append_stage(pipeline_stages_provisioning.UseSecurityClientStage()) + .append_stage(pipeline_stages_base.SasTokenRenewalStage()) # # RegistrationStage needs to come early because this is the stage that converts registration # or query requests into request and response objects which are used by later stages @@ -116,17 +115,7 @@ def _on_disconnected(): self._pipeline.on_disconnected_handler = _on_disconnected callback = EventedCallback() - - if isinstance(security_client, X509SecurityClient): - op = pipeline_ops_provisioning.SetX509SecurityClientOperation( - security_client=security_client, callback=callback - ) - elif isinstance(security_client, SymmetricKeySecurityClient): - op = pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation( - security_client=security_client, callback=callback - ) - else: - logger.error("Provisioning not equipped to handle other security client.") + op = pipeline_ops_base.InitializePipelineOperation(callback=callback) self._pipeline.run_op(op) callback.wait_for_completion() diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_ops_provisioning.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_ops_provisioning.py index eb30f6d1a..b16a72a16 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_ops_provisioning.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_ops_provisioning.py @@ -6,91 +6,6 @@ from azure.iot.device.common.pipeline.pipeline_ops_base import PipelineOperation -class SetSymmetricKeySecurityClientOperation(PipelineOperation): - """ - A PipelineOperation object which tells the pipeline to use a symmetric key security client. - Some pipeline stage is expected to extract arguments out of the security client and pass them - on so an even lower stage can use those arguments to connect. - - This operation is in the group of provisioning operations because security clients are currently - very provisioning-specific - """ - - def __init__(self, security_client, callback): - """ - Initializer for SetSecurityClient. - - :param object security_client: The security client object to use to retrieve connection parameters - which can be used to connect to the service. - :param Function callback: The function that gets called when this operation is complete or has failed. - The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. - """ - super(SetSymmetricKeySecurityClientOperation, self).__init__(callback=callback) - self.security_client = security_client - - -class SetX509SecurityClientOperation(PipelineOperation): - """ - A PipelineOperation object which contains connection arguments which were retrieved from a - X509 security client likely by a pipeline stage which handles the - SetX509SecurityClientOperation operation. - - This operation is in the group of Provisioning operations because the arguments which it accepts are - very specific to DPS connections and would not apply to other types of client connections - (such as a Provisioning client). - """ - - def __init__(self, security_client, callback): - """ - Initializer for SetSecurityClient. - - :param object security_client: The security client object to use to retrieve connection parameters - which can be used to connect to the service. - :param Function callback: The function that gets called when this operation is complete or has failed. 
- The callback function must accept A PipelineOperation object which indicates the specific operation which - has completed or failed. - """ - super(SetX509SecurityClientOperation, self).__init__(callback=callback) - self.security_client = security_client - - -class SetProvisioningClientConnectionArgsOperation(PipelineOperation): - """ - A PipelineOperation object which contains connection arguments which were retrieved from a - symmetric key or a X509 security client likely by a pipeline stage which handles the - SetSymmetricKeySecurityClientOperation or SetX509SecurityClientOperation operation. - - This operation is in the group of Provisioning operations because the arguments which it accepts are - very specific to DPS connections and would not apply to other types of client connections - (such as a Provisioning client). - """ - - def __init__( - self, - provisioning_host, - registration_id, - id_scope, - callback, - client_cert=None, - sas_token=None, - ): - """ - Initializer for SetProvisioningClientConnectionArgsOperation. - :param registration_id: The registration ID is used to uniquely identify a device in the Device Provisioning Service. - The registration ID is alphanumeric, lowercase string and may contain hyphens. - :param id_scope: The ID scope is used to uniquely identify the specific provisioning service the device will - register through. The ID scope is assigned to a Device Provisioning Service when it is created by the user and - is generated by the service and is immutable, guaranteeing uniqueness. - """ - super(SetProvisioningClientConnectionArgsOperation, self).__init__(callback=callback) - self.provisioning_host = provisioning_host - self.registration_id = registration_id - self.id_scope = id_scope - self.client_cert = client_cert - self.sas_token = sas_token - - class RegisterOperation(PipelineOperation): """ A PipelineOperation object which contains arguments used to send a registration request diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning.py index 710f11eaf..4d66c2043 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning.py @@ -22,43 +22,6 @@ logger = logging.getLogger(__name__) -class UseSecurityClientStage(PipelineStage): - """ - PipelineStage which extracts relevant SecurityClient values for a new - SetProvisioningClientConnectionArgsOperation. - - All other operations are passed down. 
- """ - - @pipeline_thread.runs_on_pipeline_thread - def _run_op(self, op): - if isinstance(op, pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation): - - security_client = op.security_client - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation, - provisioning_host=security_client.provisioning_host, - registration_id=security_client.registration_id, - id_scope=security_client.id_scope, - sas_token=security_client.get_current_sas_token(), - ) - self.send_op_down(worker_op) - - elif isinstance(op, pipeline_ops_provisioning.SetX509SecurityClientOperation): - security_client = op.security_client - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation, - provisioning_host=security_client.provisioning_host, - registration_id=security_client.registration_id, - id_scope=security_client.id_scope, - client_cert=security_client.get_x509_certificate(), - ) - self.send_op_down(worker_op) - - else: - super(UseSecurityClientStage, self)._run_op(op) - - class CommonProvisioningStage(PipelineStage): """ This is a super stage that the RegistrationStage and PollingStatusStage of @@ -180,7 +143,7 @@ def _process_failed_and_assigned_registration_status( original_provisioning_op.registration_result = complete_registration_result if registration_status == "failed": error = exceptions.ServiceError( - "Query Status operation returned a failed registration status with a status code of {status_code}".format( + "Query Status operation returned a failed registration status with a status code of {status_code}".format( status_code=request_response_op.status_code ) ) diff --git a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning_mqtt.py b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning_mqtt.py index 432a39bcf..101edfba0 100644 --- a/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning_mqtt.py +++ b/azure-iot-device/azure/iot/device/provisioning/pipeline/pipeline_stages_provisioning_mqtt.py @@ -13,13 +13,14 @@ pipeline_events_mqtt, pipeline_thread, pipeline_events_base, + pipeline_exceptions, ) from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage from azure.iot.device.provisioning.pipeline import mqtt_topic_provisioning from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning from azure.iot.device import constant as pkg_constant from . import constant as pipeline_constant -from azure.iot.device.product_info import ProductInfo +from azure.iot.device import user_agent logger = logging.getLogger(__name__) @@ -37,34 +38,27 @@ def __init__(self): @pipeline_thread.runs_on_pipeline_thread def _run_op(self, op): - if isinstance(op, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation): - # get security client args from above, save some, use some to build topic names, - # always pass it down because MQTT protocol stage will also want to receive these args. 
+ if isinstance(op, pipeline_ops_base.InitializePipelineOperation): - client_id = op.registration_id + client_id = self.pipeline_root.pipeline_configuration.registration_id query_param_seq = [ ("api-version", pkg_constant.PROVISIONING_API_VERSION), - ("ClientVersion", ProductInfo.get_provisioning_user_agent()), + ("ClientVersion", user_agent.get_provisioning_user_agent()), ] username = "{id_scope}/registrations/{registration_id}/{query_params}".format( - id_scope=op.id_scope, - registration_id=op.registration_id, + id_scope=self.pipeline_root.pipeline_configuration.id_scope, + registration_id=self.pipeline_root.pipeline_configuration.registration_id, query_params=version_compat.urlencode( query_param_seq, quote_via=urllib.parse.quote ), ) - hostname = op.provisioning_host + # Dynamically attach the derived MQTT values to the InitalizePipelineOperation + # to be used later down the pipeline + op.username = username + op.client_id = client_id - worker_op = op.spawn_worker_op( - worker_op_type=pipeline_ops_mqtt.SetMQTTConnectionArgsOperation, - client_id=client_id, - hostname=hostname, - username=username, - client_cert=op.client_cert, - sas_token=op.sas_token, - ) - self.send_op_down(worker_op) + self.send_op_down(op) elif isinstance(op, pipeline_ops_base.RequestOperation): if op.request_type == pipeline_constant.REGISTER: @@ -77,7 +71,7 @@ def _run_op(self, op): payload=op.request_body, ) self.send_op_down(worker_op) - else: + elif op.request_type == pipeline_constant.QUERY: topic = mqtt_topic_provisioning.get_query_topic_for_publish( request_id=op.request_id, operation_id=op.query_params["operation_id"] ) @@ -87,8 +81,17 @@ def _run_op(self, op): payload=op.request_body, ) self.send_op_down(worker_op) + else: + raise pipeline_exceptions.OperationError( + "RequestOperation request_type {} not supported".format(op.request_type) + ) elif isinstance(op, pipeline_ops_base.EnableFeatureOperation): + # The only supported feature is REGISTER + if not op.feature_name == pipeline_constant.REGISTER: + raise pipeline_exceptions.OperationError( + "Trying to enable/disable invalid feature - {}".format(op.feature_name) + ) # Enabling for register gets translated into an MQTT subscribe operation topic = mqtt_topic_provisioning.get_register_topic_for_subscribe() worker_op = op.spawn_worker_op( @@ -97,6 +100,11 @@ def _run_op(self, op): self.send_op_down(worker_op) elif isinstance(op, pipeline_ops_base.DisableFeatureOperation): + # The only supported feature is REGISTER + if not op.feature_name == pipeline_constant.REGISTER: + raise pipeline_exceptions.OperationError( + "Trying to enable/disable invalid feature - {}".format(op.feature_name) + ) # Disabling a register response gets turned into an MQTT unsubscribe operation topic = mqtt_topic_provisioning.get_register_topic_for_subscribe() worker_op = op.spawn_worker_op( @@ -146,4 +154,4 @@ def _handle_pipeline_event(self, event): else: # all other messages get passed up - super(ProvisioningMQTTTranslationStage, self)._handle_pipeline_event(event) + self.send_event_up(event) diff --git a/azure-iot-device/azure/iot/device/provisioning/provisioning_device_client.py b/azure-iot-device/azure/iot/device/provisioning/provisioning_device_client.py index fc94bb9f2..81d78b3e8 100644 --- a/azure-iot-device/azure/iot/device/provisioning/provisioning_device_client.py +++ b/azure-iot-device/azure/iot/device/provisioning/provisioning_device_client.py @@ -65,14 +65,12 @@ def register(self): """ logger.info("Registering with Provisioning Service...") - if not 
self._provisioning_pipeline.responses_enabled[dps_constant.REGISTER]: + if not self._pipeline.responses_enabled[dps_constant.REGISTER]: self._enable_responses() register_complete = EventedCallback(return_arg_name="result") - self._provisioning_pipeline.register( - payload=self._provisioning_payload, callback=register_complete - ) + self._pipeline.register(payload=self._provisioning_payload, callback=register_complete) result = handle_result(register_complete) @@ -89,7 +87,7 @@ def _enable_responses(self): logger.info("Enabling reception of response from Device Provisioning Service...") subscription_complete = EventedCallback() - self._provisioning_pipeline.enable_responses(callback=subscription_complete) + self._pipeline.enable_responses(callback=subscription_complete) handle_result(subscription_complete) diff --git a/azure-iot-device/azure/iot/device/provisioning/security/__init__.py b/azure-iot-device/azure/iot/device/provisioning/security/__init__.py deleted file mode 100644 index d71c1b4de..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/security/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Azure Provisioning Device Security - -This package provides security and authentication-related functionality for use with the -Azure Provisioning Device SDK. -""" - -from .sk_security_client import SymmetricKeySecurityClient -from .x509_security_client import X509SecurityClient diff --git a/azure-iot-device/azure/iot/device/provisioning/security/sk_security_client.py b/azure-iot-device/azure/iot/device/provisioning/security/sk_security_client.py deleted file mode 100644 index 69aa6a983..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/security/sk_security_client.py +++ /dev/null @@ -1,87 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains a client that is responsible for providing shared access tokens that will eventually establish - the authenticity of devices to Device Provisioning Service. -""" - -from azure.iot.device.common.sastoken import SasToken - - -class SymmetricKeySecurityClient(object): - """ - A client that is responsible for providing shared access tokens that will eventually establish - the authenticity of devices to Device Provisioning Service. - :ivar provisioning_host: Host running the Device Provisioning Service - :ivar registration_id: : The registration ID is used to uniquely identify a device in the Device Provisioning Service. - :ivar id_scope: : The ID scope is used to uniquely identify the specific provisioning service the device will - register through. - """ - - def __init__(self, provisioning_host, registration_id, id_scope, symmetric_key): - """ - Initialize the symmetric key security client. - :param provisioning_host: Host running the Device Provisioning Service. Can be found in the Azure portal in the - Overview tab as the string Global device endpoint - :param registration_id: The registration ID is used to uniquely identify a device in the Device Provisioning Service. - The registration ID is alphanumeric, lowercase string and may contain hyphens. - :param id_scope: The ID scope is used to uniquely identify the specific provisioning service the device will - register through. 
The ID scope is assigned to a Device Provisioning Service when it is created by the user and - is generated by the service and is immutable, guaranteeing uniqueness. - :param symmetric_key: The key which will be used to create the shared access signature token to authenticate - the device with the Device Provisioning Service. By default, the Device Provisioning Service creates - new symmetric keys with a default length of 32 bytes when new enrollments are saved with the Auto-generate keys - option enabled. Users can provide their own symmetric keys for enrollments by disabling this option within - 16 bytes and 64 bytes and in valid Base64 format. - """ - self._provisioning_host = provisioning_host - self._registration_id = registration_id - self._id_scope = id_scope - self._symmetric_key = symmetric_key - self._sas_token = None - - @property - def provisioning_host(self): - """ - :return: The registration ID is used to uniquely identify a device in the Device Provisioning Service. - The registration ID is alphanumeric, lowercase string and may contain hyphens. - """ - return self._provisioning_host - - @property - def registration_id(self): - """ - :return: The registration ID is used to uniquely identify a device in the Device Provisioning Service. - The registration ID is alphanumeric, lowercase string and may contain hyphens. - """ - return self._registration_id - - @property - def id_scope(self): - """ - :return: Host running the Device Provisioning Service. - """ - return self._id_scope - - def _create_shared_access_signature(self): - """ - Construct SAS tokens that have a hashed signature formed using the symmetric key of this security client. - This signature is recreated by the Device Provisioning Service to verify whether a security token presented - during attestation is authentic or not. - :return: A string representation of the shared access signature which is of the form - SharedAccessSignature sig={signature}&se={expiry}&skn={policyName}&sr={URL-encoded-resourceURI} - """ - uri = self._id_scope + "/registrations/" + self._registration_id - key = self._symmetric_key - time_to_live = 3600 - keyname = "registration" - return SasToken(uri, key, keyname, time_to_live) - - def get_current_sas_token(self): - if self._sas_token is None: - self._sas_token = self._create_shared_access_signature() - else: - self._sas_token.refresh() - return str(self._sas_token) diff --git a/azure-iot-device/azure/iot/device/provisioning/security/x509_security_client.py b/azure-iot-device/azure/iot/device/provisioning/security/x509_security_client.py deleted file mode 100644 index 766da320c..000000000 --- a/azure-iot-device/azure/iot/device/provisioning/security/x509_security_client.py +++ /dev/null @@ -1,64 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -"""This module contains a client that is responsible for providing x509 certificates that will eventually establish - the authenticity of devices to Device Provisioning Service. -""" - - -class X509SecurityClient(object): - """ - An X509 Security Client. This uses the certificate and key provided to authenticate a device - with an Azure DPS instance.X509 Authentication is only supported for device identities - connecting directly to an Azure DPS. 
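# NOTE: Illustrative sketch (not part of this diff): with the security clients
# deleted, the X509 flow mirrors create_from_x509_certificate above and simply
# passes the X509 object into the pipeline configuration. File paths and
# identifiers below are placeholders.
from azure.iot.device import X509
from azure.iot.device.provisioning import pipeline

x509 = X509(cert_file="/path/to/device-cert.pem", key_file="/path/to/device-key.pem")
pipeline_configuration = pipeline.ProvisioningPipelineConfig(
    hostname="global.azure-devices-provisioning.net",
    registration_id="my-device-001",
    id_scope="0ne00000000",
    x509=x509,
)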
- """ - - def __init__(self, provisioning_host, registration_id, id_scope, x509): - """ - Initialize the X509 Certificate security client. - :param provisioning_host: Host running the Device Provisioning Service. Can be found in the Azure portal in the - Overview tab as the string Global device endpoint - :param registration_id: The registration ID is used to uniquely identify a device in the Device Provisioning Service. - The registration ID is alphanumeric, lowercase string and may contain hyphens. - :param id_scope: The ID scope is used to uniquely identify the specific provisioning service the device will - register through. The ID scope is assigned to a Device Provisioning Service when it is created by the user and - is generated by the service and is immutable, guaranteeing uniqueness. - :param x509: The x509 certificate, To use the certificate the enrollment object needs to contain cert (either the root certificate or one of the intermediate CA certificates). - If the cert comes from a CER file, it needs to be base64 encoded. - """ - self._provisioning_host = provisioning_host - self._registration_id = registration_id - self._id_scope = id_scope - self._x509 = x509 - - @property - def provisioning_host(self): - """ - :return: The registration ID is used to uniquely identify a device in the Device Provisioning Service. - The registration ID is alphanumeric, lowercase string and may contain hyphens. - """ - return self._provisioning_host - - @property - def registration_id(self): - """ - :return: The registration ID is used to uniquely identify a device in the Device Provisioning Service. - The registration ID is alphanumeric, lowercase string and may contain hyphens. - """ - return self._registration_id - - @property - def id_scope(self): - """ - :return: Host running the Device Provisioning Service. - """ - return self._id_scope - - def get_x509_certificate(self): - """ - :return: The x509 certificate, To use the certificate the enrollment object needs to contain - cert (either the root certificate or one of the intermediate CA certificates). - """ - return self._x509 diff --git a/azure-iot-device/azure/iot/device/user_agent.py b/azure-iot-device/azure/iot/device/user_agent.py new file mode 100644 index 000000000..a07615629 --- /dev/null +++ b/azure-iot-device/azure/iot/device/user_agent.py @@ -0,0 +1,41 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +"""This module is for creating agent strings for all clients""" + +import platform +from azure.iot.device.constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER + +python_runtime = platform.python_version() +os_type = platform.system() +os_release = platform.version() +architecture = platform.machine() + + +def _get_common_user_agent(): + return "({python_runtime};{os_type} {os_release};{architecture})".format( + python_runtime=python_runtime, + os_type=os_type, + os_release=os_release, + architecture=architecture, + ) + + +def get_iothub_user_agent(): + """ + Create the user agent for IotHub + """ + return "{iothub_iden}/{version}{common}".format( + iothub_iden=IOTHUB_IDENTIFIER, version=VERSION, common=_get_common_user_agent() + ) + + +def get_provisioning_user_agent(): + """ + Create the user agent for Provisioning + """ + return "{provisioning_iden}/{version}{common}".format( + provisioning_iden=PROVISIONING_IDENTIFIER, version=VERSION, common=_get_common_user_agent() + ) diff --git a/azure-iot-device/tests/common/test_connection_string.py b/azure-iot-device/tests/common/auth/test_connection_string.py similarity index 82% rename from azure-iot-device/tests/common/test_connection_string.py rename to azure-iot-device/tests/common/auth/test_connection_string.py index 15a3f77b5..566559660 100644 --- a/azure-iot-device/tests/common/test_connection_string.py +++ b/azure-iot-device/tests/common/auth/test_connection_string.py @@ -6,7 +6,8 @@ import pytest import logging -from azure.iot.device.common.connection_string import ConnectionString +import six +from azure.iot.device.common.auth.connection_string import ConnectionString logging.basicConfig(level=logging.DEBUG) @@ -43,7 +44,7 @@ def test_instantiates_correctly_from_string(self, input_string): cs = ConnectionString(input_string) assert isinstance(cs, ConnectionString) - @pytest.mark.it("Raises ValueError on bad input") + @pytest.mark.it("Raises ValueError on invalid string input during instantiation") @pytest.mark.parametrize( "input_string", [ @@ -60,10 +61,30 @@ def test_instantiates_correctly_from_string(self, input_string): ), ], ) - def test_raises_value_error_on_bad_input(self, input_string): + def test_raises_value_error_on_invalid_input(self, input_string): with pytest.raises(ValueError): ConnectionString(input_string) + @pytest.mark.it("Raises TypeError on non-string input during instantiation") + @pytest.mark.parametrize( + "input_val", + [ + pytest.param(2123, id="Integer"), + pytest.param(23.098, id="Float"), + pytest.param( + b"bytes", + id="Bytes", + marks=pytest.mark.xfail(six.PY2, reason="Bytes are valid in Python 2.7"), + ), + pytest.param(object(), id="Complex object"), + pytest.param(["a", "b"], id="List"), + pytest.param({"a": "b"}, id="Dictionary"), + ], + ) + def test_raises_type_error_on_non_string_input(self, input_val): + with pytest.raises(TypeError): + ConnectionString(input_val) + @pytest.mark.it("Uses the input connection string as a string representation") def test_string_representation_of_object_is_the_input_string(self): string = "HostName=my.host.name;SharedAccessKeyName=mykeyname;SharedAccessKey=Zm9vYmFy" diff --git a/azure-iot-device/tests/common/auth/test_sastoken.py b/azure-iot-device/tests/common/auth/test_sastoken.py new file mode 100644 index 000000000..720657c22 --- /dev/null +++ b/azure-iot-device/tests/common/auth/test_sastoken.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +# 
------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest +import time +import re +import logging +import six.moves.urllib as urllib +from azure.iot.device.common.auth.sastoken import SasToken, SasTokenError + +logging.basicConfig(level=logging.DEBUG) + +fake_uri = "some/resource/location" +fake_signed_data = "ajsc8nLKacIjGsYyB4iYDFCZaRMmmDrUuY5lncYDYPI=" +fake_key_name = "fakekeyname" + + +def token_parser(token_str): + """helper function that parses a token string for indvidual values""" + token_map = {} + kv_string = token_str.split(" ")[1] + kv_pairs = kv_string.split("&") + for kv in kv_pairs: + t = kv.split("=") + token_map[t[0]] = t[1] + return token_map + + +@pytest.fixture +def signing_mechanism(mocker): + mechanism = mocker.MagicMock() + mechanism.sign.return_value = fake_signed_data + return mechanism + + +# TODO: Rename this. These are not "device" and "service" tokens, the distinction is more generic +@pytest.fixture(params=["Device Token", "Service Token"]) +def sastoken(request, signing_mechanism): + token_type = request.param + if token_type == "Device Token": + return SasToken(uri=fake_uri, signing_mechanism=signing_mechanism) + elif token_type == "Service Token": + return SasToken(uri=fake_uri, signing_mechanism=signing_mechanism, key_name=fake_key_name) + + +@pytest.mark.describe("SasToken") +class TestSasToken(object): + @pytest.mark.it("Instantiates with a default TTL of 3600 seconds if no TTL is provided") + def test_default_ttl(self, signing_mechanism): + s = SasToken(fake_uri, signing_mechanism) + assert s.ttl == 3600 + + @pytest.mark.it("Instantiates with a custom TTL if provided") + def test_custom_ttl(self, signing_mechanism): + custom_ttl = 4747 + s = SasToken(fake_uri, signing_mechanism, ttl=custom_ttl) + assert s.ttl == custom_ttl + + @pytest.mark.it("Instantiates with with no key name by default if no key name is provided") + def test_default_key_name(self, signing_mechanism): + s = SasToken(fake_uri, signing_mechanism) + assert s._key_name is None + + @pytest.mark.it("Instantiates with the given key name if provided") + def test_custom_key_name(self, signing_mechanism): + s = SasToken(fake_uri, signing_mechanism, key_name=fake_key_name) + assert s._key_name == fake_key_name + + @pytest.mark.it( + "Instantiates with an expiry time TTL seconds in the future from the moment of instantiation" + ) + def test_expiry_time(self, mocker, signing_mechanism): + fake_current_time = 1000 + mocker.patch.object(time, "time", return_value=fake_current_time) + + s = SasToken(fake_uri, signing_mechanism) + assert s.expiry_time == fake_current_time + s.ttl + + @pytest.mark.it("Calls .refresh() to build the SAS token string on instantiation") + def test_refresh_on_instantiation(self, mocker, signing_mechanism): + refresh_mock = mocker.spy(SasToken, "refresh") + assert refresh_mock.call_count == 0 + SasToken(fake_uri, signing_mechanism) + assert refresh_mock.call_count == 1 + + @pytest.mark.it("Returns the SAS token string as the string representation of the object") + def test_str_rep(self, sastoken): + assert str(sastoken) == sastoken._token + + @pytest.mark.it( + "Maintains the .expiry_time attribute as a read-only property (raises AttributeError upon attempt)" + ) + def test_expiry_time_read_only(self, 
sastoken): + with pytest.raises(AttributeError): + sastoken.expiry_time = 12321312 + + +@pytest.mark.describe("SasToken - .refresh()") +class TestSasTokenRefresh(object): + @pytest.mark.it("Sets a new expiry time of TTL seconds in the future") + def test_new_expiry(self, mocker, sastoken): + fake_current_time = 1000 + mocker.patch.object(time, "time", return_value=fake_current_time) + sastoken.refresh() + assert sastoken.expiry_time == fake_current_time + sastoken.ttl + + # TODO: reflect url encoding here? + @pytest.mark.it( + "Uses the token's signing mechanism to create a signature by signing a concatenation of the (URL encoded) URI and updated expiry time" + ) + def test_generate_new_token(self, mocker, signing_mechanism, sastoken): + old_token_str = str(sastoken) + fake_future_time = 1000 + mocker.patch.object(time, "time", return_value=fake_future_time) + signing_mechanism.reset_mock() + fake_signature = "new_fake_signature" + signing_mechanism.sign.return_value = fake_signature + + sastoken.refresh() + + # The token string has been updated + assert str(sastoken) != old_token_str + # The signing mechanism was used to sign a string + assert signing_mechanism.sign.call_count == 1 + # The string being signed was a concatenation of the URI and expiry time + assert signing_mechanism.sign.call_args == mocker.call( + urllib.parse.quote(sastoken._uri, safe="") + "\n" + str(sastoken.expiry_time) + ) + # The token string has the resulting signed string included as the signature + token_info = token_parser(str(sastoken)) + assert token_info["sig"] == fake_signature + + @pytest.mark.it( + "Builds a new token string using the token's URI (URL encoded) and expiry time, along with the signature created by the signing mechanism (also URL encoded)" + ) + def test_token_string(self, sastoken): + token_str = sastoken._token + + # Verify that token string representation matches token format + if not sastoken._key_name: + pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)") + else: + pattern = re.compile(r"SharedAccessSignature sr=(.+)&sig=(.+)&se=(.+)&skn=(.+)") + assert pattern.match(token_str) + + # Verify that content in the string representation is correct + token_info = token_parser(token_str) + assert token_info["sr"] == urllib.parse.quote(sastoken._uri, safe="") + assert token_info["sig"] == urllib.parse.quote( + sastoken._signing_mechanism.sign.return_value, safe="" + ) + assert token_info["se"] == str(sastoken.expiry_time) + if sastoken._key_name: + assert token_info["skn"] == sastoken._key_name + + @pytest.mark.it("Raises a SasTokenError if an exception is raised by the signing mechanism") + def test_signing_mechanism_raises_value_error( + self, mocker, signing_mechanism, sastoken, arbitrary_exception + ): + signing_mechanism.sign.side_effect = arbitrary_exception + + with pytest.raises(SasTokenError) as e_info: + sastoken.refresh() + assert e_info.value.__cause__ is arbitrary_exception diff --git a/azure-iot-device/tests/common/auth/test_signing_mechanism.py b/azure-iot-device/tests/common/auth/test_signing_mechanism.py new file mode 100644 index 000000000..f0f8a6b76 --- /dev/null +++ b/azure-iot-device/tests/common/auth/test_signing_mechanism.py @@ -0,0 +1,128 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import pytest +import logging +import hmac +import hashlib +import base64 +from azure.iot.device.common.auth import SymmetricKeySigningMechanism + +logging.basicConfig(level=logging.DEBUG) + + +@pytest.mark.describe("SymmetricKeySigningMechanism - Instantiation") +class TestSymmetricKeySigningMechanismInstantiation(object): + @pytest.mark.it( + "Derives and stores the signing key from the provided symmetric key by base64 decoding it" + ) + @pytest.mark.parametrize( + "key, expected_signing_key", + [ + pytest.param( + "NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=", + b"4\xc8\t\x0e\xf7JO\x18\xcb\x8b\xecA\xc7\x19\x03\x0cL\x03'\x11/8Nn\xf0\x18\x93\xd2e`=\xe8", + id="Example 1", + ), + pytest.param( + "zqtyZCGuKg/UHvSzgYnNod/uHChWrzGGtHSgPi4cC2U=", + b"\xce\xabrd!\xae*\x0f\xd4\x1e\xf4\xb3\x81\x89\xcd\xa1\xdf\xee\x1c(V\xaf1\x86\xb4t\xa0>.\x1c\x0be", + id="Example 2", + ), + ], + ) + def test_dervies_signing_key(self, key, expected_signing_key): + sm = SymmetricKeySigningMechanism(key) + assert sm._signing_key == expected_signing_key + + @pytest.mark.it("Supports symmetric keys in both string and byte formats") + @pytest.mark.parametrize( + "key, expected_signing_key", + [ + pytest.param( + "NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=", + b"4\xc8\t\x0e\xf7JO\x18\xcb\x8b\xecA\xc7\x19\x03\x0cL\x03'\x11/8Nn\xf0\x18\x93\xd2e`=\xe8", + id="String", + ), + pytest.param( + b"NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=", + b"4\xc8\t\x0e\xf7JO\x18\xcb\x8b\xecA\xc7\x19\x03\x0cL\x03'\x11/8Nn\xf0\x18\x93\xd2e`=\xe8", + id="Bytes", + ), + ], + ) + def test_supported_types(self, key, expected_signing_key): + sm = SymmetricKeySigningMechanism(key) + assert sm._signing_key == expected_signing_key + + @pytest.mark.it("Raises a ValueError if the provided symmetric key is invalid") + @pytest.mark.parametrize( + "key", + [pytest.param("not a key", id="Not a key"), pytest.param("YWJjx", id="Incomplete key")], + ) + def test_invalid_key(self, key): + with pytest.raises(ValueError): + SymmetricKeySigningMechanism(key) + + +@pytest.mark.describe("SymmetricKeySigningMechanism - .sign()") +class TestSymmetricKeySigningMechanismSign(object): + @pytest.fixture + def signing_mechanism(self): + return SymmetricKeySigningMechanism("NMgJDvdKTxjLi+xBxxkDDEwDJxEvOE5u8BiT0mVgPeg=") + + @pytest.mark.it( + "Generates an HMAC message digest from the signing key and provided data string, using the HMAC-SHA256 algorithm" + ) + def test_hmac(self, mocker, signing_mechanism): + hmac_mock = mocker.patch.object(hmac, "HMAC") + hmac_digest_mock = hmac_mock.return_value.digest + hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z" + + data_string = "sign this message" + signing_mechanism.sign(data_string) + + assert hmac_mock.call_count == 1 + assert hmac_mock.call_args == mocker.call( + key=signing_mechanism._signing_key, + msg=data_string.encode("utf-8"), + digestmod=hashlib.sha256, + ) + assert hmac_digest_mock.call_count == 1 + + @pytest.mark.it( + "Returns the base64 encoded HMAC message digest (converted to string) as the signed data" + ) + def test_b64encode(self, mocker, signing_mechanism): + hmac_mock = mocker.patch.object(hmac, "HMAC") + hmac_digest_mock = hmac_mock.return_value.digest + hmac_digest_mock.return_value = b"\xd2\x06\xf7\x12\xf1\xe9\x95$\x90\xfd\x12\x9a\xb1\xbe\xb4\xf8\xf3\xc4\x1ap\x8a\xab'\x8a.D\xfb\x84\x96\xca\xf3z" + + data_string = "sign this 
message" + signature = signing_mechanism.sign(data_string) + + assert signature == base64.b64encode(hmac_digest_mock.return_value).decode("utf-8") + + @pytest.mark.it("Supports data strings in both string and byte formats") + @pytest.mark.parametrize( + "data_string, expected_signature", + [ + pytest.param( + "sign this message", "8NJRMT83CcplGrAGaUVIUM/md5914KpWVNngSVoF9/M=", id="String" + ), + pytest.param( + b"sign this message", "8NJRMT83CcplGrAGaUVIUM/md5914KpWVNngSVoF9/M=", id="Bytes" + ), + ], + ) + def test_supported_types(self, signing_mechanism, data_string, expected_signature): + assert signing_mechanism.sign(data_string) == expected_signature + + @pytest.mark.it("Raises a ValueError if unable to sign the provided data string") + @pytest.mark.parametrize("data_string", [pytest.param(123, id="Integer input")]) + def test_bad_input(self, signing_mechanism, data_string): + with pytest.raises(ValueError): + signing_mechanism.sign(data_string) diff --git a/azure-iot-device/tests/common/pipeline/config_test.py b/azure-iot-device/tests/common/pipeline/config_test.py new file mode 100644 index 000000000..c11726393 --- /dev/null +++ b/azure-iot-device/tests/common/pipeline/config_test.py @@ -0,0 +1,230 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +import abc +import six +from azure.iot.device import ProxyOptions + + +@six.add_metaclass(abc.ABCMeta) +class PipelineConfigInstantiationTestBase(object): + """All PipelineConfig instantiation tests should inherit from this base class. + It provides tests for shared functionality among all PipelineConfigs, derived from + the BasePipelineConfig class. + """ + + @abc.abstractmethod + def config_cls(self): + """This must be implemented in the child test class. + It returns the child class under test""" + pass + + @abc.abstractmethod + def required_kwargs(self): + """This must be implemented in the child test class. + It returns required kwargs for the child class under test""" + pass + + # PipelineConfig objects require exactly one auth mechanism, sastoken or x509. + # For the sake of ease of testing, we will assume sastoken is being used unless + # otherwise specified. + # It does not matter which is used for the purposes of these tests. 
+ + @pytest.fixture + def sastoken(self, mocker): + return mocker.MagicMock() + + @pytest.fixture + def x509(self, mocker): + return mocker.MagicMock() + + @pytest.mark.it( + "Instantiates with the 'hostname' attribute set to the provided 'hostname' parameter" + ) + def test_hostname_set(self, config_cls, required_kwargs, sastoken): + # Hostname is one of the required kwargs, because it is required for the child + config = config_cls(sastoken=sastoken, **required_kwargs) + assert config.hostname == required_kwargs["hostname"] + + @pytest.mark.it( + "Instantiates with the 'gateway_hostname' attribute set to the provided 'gateway_hostname' parameter" + ) + def test_gateway_hostname_set(self, config_cls, required_kwargs, sastoken): + fake_gateway_hostname = "gateway-hostname.some-domain.net" + config = config_cls( + sastoken=sastoken, gateway_hostname=fake_gateway_hostname, **required_kwargs + ) + assert config.gateway_hostname == fake_gateway_hostname + + @pytest.mark.it( + "Instantiates with the 'gateway_hostname' attribute set to 'None' if no 'gateway_hostname' parameter is provided" + ) + def test_gateway_hostname_default(self, config_cls, required_kwargs, sastoken): + config = config_cls(sastoken=sastoken, **required_kwargs) + assert config.gateway_hostname is None + + @pytest.mark.it( + "Instantiates with the 'sastoken' attribute set to the provided 'sastoken' parameter" + ) + def test_sastoken_set(self, config_cls, required_kwargs, sastoken): + config = config_cls(sastoken=sastoken, **required_kwargs) + assert config.sastoken is sastoken + + @pytest.mark.it( + "Instantiates with the 'sastoken' attribute set to 'None' if no 'sastoken' parameter is provided" + ) + def test_sastoken_default(self, config_cls, required_kwargs, x509): + config = config_cls(x509=x509, **required_kwargs) + assert config.sastoken is None + + @pytest.mark.it("Instantiates with the 'x509' attribute set to the provided 'x509' parameter") + def test_x509_set(self, config_cls, required_kwargs, x509): + config = config_cls(x509=x509, **required_kwargs) + assert config.x509 is x509 + + @pytest.mark.it( + "Instantiates with the 'x509' attribute set to 'None' if no 'x509 paramater is provided" + ) + def test_x509_default(self, config_cls, required_kwargs, sastoken): + config = config_cls(sastoken=sastoken, **required_kwargs) + assert config.x509 is None + + @pytest.mark.it( + "Raises a ValueError if neither the 'sastoken' nor 'x509' parameter is provided" + ) + def test_no_auths_provided(self, config_cls, required_kwargs): + with pytest.raises(ValueError): + config_cls(**required_kwargs) + + @pytest.mark.it("Raises a ValueError if both the 'sastoken' and 'x509' parameters are provided") + def test_both_auths_provided(self, config_cls, required_kwargs, sastoken, x509): + with pytest.raises(ValueError): + config_cls(sastoken=sastoken, x509=x509, **required_kwargs) + + @pytest.mark.it( + "Instantiates with the 'server_verification_cert' attribute set to the provided 'server_verification_cert' parameter" + ) + def test_server_verification_cert_set(self, config_cls, required_kwargs, sastoken): + svc = "fake_server_verification_cert" + config = config_cls(sastoken=sastoken, server_verification_cert=svc, **required_kwargs) + assert config.server_verification_cert == svc + + @pytest.mark.it( + "Instantiates with the 'server_verification_cert' attribute set to 'None' if no 'server_verification_cert' paramater is provided" + ) + def test_server_verification_cert_default(self, config_cls, required_kwargs, sastoken): + config = 
config_cls(sastoken=sastoken, **required_kwargs) + assert config.server_verification_cert is None + + @pytest.mark.it( + "Instantiates with the 'websockets' attribute set to the provided 'websockets' parameter" + ) + @pytest.mark.parametrize( + "websockets", [True, False], ids=["websockets == True", "websockets == False"] + ) + def test_websockets_set(self, config_cls, required_kwargs, sastoken, websockets): + config = config_cls(sastoken=sastoken, websockets=websockets, **required_kwargs) + assert config.websockets is websockets + + @pytest.mark.it( + "Instantiates with the 'websockets' attribute to 'False' if no 'websockets' parameter is provided" + ) + def test_websockets_default(self, config_cls, required_kwargs, sastoken): + config = config_cls(sastoken=sastoken, **required_kwargs) + assert config.websockets is False + + @pytest.mark.it( + "Instantiates with the 'cipher' attribute set to OpenSSL list formatted version of the provided 'cipher' parameter" + ) + @pytest.mark.parametrize( + "cipher_input, expected_cipher", + [ + pytest.param( + "DHE-RSA-AES128-SHA", + "DHE-RSA-AES128-SHA", + id="Single cipher suite, OpenSSL list formatted string", + ), + pytest.param( + "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", + "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", + id="Multiple cipher suites, OpenSSL list formatted string", + ), + pytest.param( + "DHE_RSA_AES128_SHA", + "DHE-RSA-AES128-SHA", + id="Single cipher suite, as string with '_' delimited algorithms/protocols", + ), + pytest.param( + "DHE_RSA_AES128_SHA:DHE_RSA_AES256_SHA:ECDHE_ECDSA_AES128_GCM_SHA256", + "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", + id="Multiple cipher suites, as string with '_' delimited algorithms/protocols and ':' delimited suites", + ), + pytest.param( + ["DHE-RSA-AES128-SHA"], + "DHE-RSA-AES128-SHA", + id="Single cipher suite, in a list, with '-' delimited algorithms/protocols", + ), + pytest.param( + ["DHE-RSA-AES128-SHA", "DHE-RSA-AES256-SHA", "ECDHE-ECDSA-AES128-GCM-SHA256"], + "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", + id="Multiple cipher suites, in a list, with '-' delimited algorithms/protocols", + ), + pytest.param( + ["DHE_RSA_AES128_SHA"], + "DHE-RSA-AES128-SHA", + id="Single cipher suite, in a list, with '_' delimited algorithms/protocols", + ), + pytest.param( + ["DHE_RSA_AES128_SHA", "DHE_RSA_AES256_SHA", "ECDHE_ECDSA_AES128_GCM_SHA256"], + "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", + id="Multiple cipher suites, in a list, with '_' delimited algorithms/protocols", + ), + ], + ) + def test_cipher(self, config_cls, required_kwargs, sastoken, cipher_input, expected_cipher): + config = config_cls(sastoken=sastoken, cipher=cipher_input, **required_kwargs) + assert config.cipher == expected_cipher + + @pytest.mark.it( + "Raises TypeError if the provided 'cipher' attribute is neither list nor string" + ) + @pytest.mark.parametrize( + "cipher", + [ + pytest.param(123, id="int"), + pytest.param( + {"cipher1": "DHE-RSA-AES128-SHA", "cipher2": "DHE_RSA_AES256_SHA"}, id="dict" + ), + pytest.param(object(), id="complex object"), + ], + ) + def test_invalid_cipher_param(self, config_cls, required_kwargs, sastoken, cipher): + with pytest.raises(TypeError): + config_cls(sastoken=sastoken, cipher=cipher, **required_kwargs) + + @pytest.mark.it( + "Instantiates with the 'cipher' attribute to empty string ('') if no 'cipher' parameter is provided" + ) + def test_cipher_default(self, 
config_cls, required_kwargs, sastoken): + config = config_cls(sastoken=sastoken, **required_kwargs) + assert config.cipher == "" + + @pytest.mark.it( + "Instantiates with the 'proxy_options' attribute set to the ProxyOptions object provided in the 'proxy_options' parameter" + ) + def test_proxy_options(self, mocker, required_kwargs, config_cls, sastoken): + proxy_options = ProxyOptions( + proxy_type=mocker.MagicMock(), proxy_addr="127.0.0.1", proxy_port=8888 + ) + config = config_cls(sastoken=sastoken, proxy_options=proxy_options, **required_kwargs) + assert config.proxy_options is proxy_options + + @pytest.mark.it( + "Instantiates with the 'proxy_options' attribute to 'None' if no 'proxy_options' parameter is provided" + ) + def test_proxy_options_default(self, config_cls, required_kwargs, sastoken): + config = config_cls(sastoken=sastoken, **required_kwargs) + assert config.proxy_options is None diff --git a/azure-iot-device/tests/common/pipeline/helpers.py b/azure-iot-device/tests/common/pipeline/helpers.py index c5c0ab818..eae8c7229 100644 --- a/azure-iot-device/tests/common/pipeline/helpers.py +++ b/azure-iot-device/tests/common/pipeline/helpers.py @@ -33,14 +33,11 @@ class StageRunOpTestBase(object): ) def test_completes_operation_with_error(self, mocker, stage, op, arbitrary_exception): stage._run_op = mocker.MagicMock(side_effect=arbitrary_exception) - # mocker.spy(op, "complete") stage.run_op(op) assert op.completed assert op.error is arbitrary_exception - # assert op.complete.call_count == 1 - # assert op.complete.call_args == mocker.call(error=arbitrary_exception) @pytest.mark.it( "Allows any BaseException that was raised during execution of the operation to propogate" @@ -79,177 +76,3 @@ def test_base_exception_propogates(self, mocker, stage, event, arbitrary_base_ex with pytest.raises(arbitrary_base_exception.__class__) as e_info: stage.handle_pipeline_event(event) assert e_info.value is arbitrary_base_exception - - -############################################ -# EVERYTHING BELOW THIS POINT IS DEPRECATED# -############################################ -# CT-TODO: remove - -all_common_ops = [ - pipeline_ops_base.ConnectOperation, - pipeline_ops_base.ReauthorizeConnectionOperation, - pipeline_ops_base.DisconnectOperation, - pipeline_ops_base.EnableFeatureOperation, - pipeline_ops_base.DisableFeatureOperation, - pipeline_ops_base.UpdateSasTokenOperation, - pipeline_ops_base.RequestAndResponseOperation, - pipeline_ops_base.RequestOperation, - pipeline_ops_mqtt.SetMQTTConnectionArgsOperation, - pipeline_ops_mqtt.MQTTPublishOperation, - pipeline_ops_mqtt.MQTTSubscribeOperation, - pipeline_ops_mqtt.MQTTUnsubscribeOperation, -] - -all_common_events = [pipeline_events_mqtt.IncomingMQTTMessageEvent] - - -def all_except(all_items, items_to_exclude): - """ - helper function to return a new list with all ops that are in the first list - and not in the second list. - - :param list all_items: list of all operations or events - :param list items_to_exclude: ops or events to exclude - """ - return [x for x in all_items if x not in items_to_exclude] - - -class StageTestBase(object): - @pytest.fixture(autouse=True) - def stage_base_configuration(self, stage, mocker): - """ - This fixture configures the stage for testing. This is automatically - applied, so it will be called before your test runs, but it's not - guaranteed to be called before any other fixtures run. 
If you have - a fixture that needs to rely on the stage being configured, then - you have to add a manual dependency inside that fixture (like we do in - next_stage_succeeds_all_ops below) - """ - - class NextStageForTest(pipeline_stages_base.PipelineStage): - def _run_op(self, op): - pass - - next = NextStageForTest() - root = ( - pipeline_stages_base.PipelineRootStage(config.BasePipelineConfig()) - .append_stage(stage) - .append_stage(next) - ) - - mocker.spy(stage, "_run_op") - mocker.spy(stage, "run_op") - - mocker.spy(next, "_run_op") - mocker.spy(next, "run_op") - - return root - - @pytest.fixture - def next_stage_succeeds(self, stage, stage_base_configuration, mocker): - def complete_op_success(op): - op.complete() - - stage.next._run_op = complete_op_success - mocker.spy(stage.next, "_run_op") - - @pytest.fixture - def next_stage_raises_arbitrary_exception( - self, stage, stage_base_configuration, mocker, arbitrary_exception - ): - stage.next._run_op = mocker.MagicMock(side_effect=arbitrary_exception) - - @pytest.fixture - def next_stage_raises_arbitrary_base_exception( - self, stage, stage_base_configuration, mocker, arbitrary_base_exception - ): - stage.next._run_op = mocker.MagicMock(side_effect=arbitrary_base_exception) - - -def assert_callback_succeeded(op, callback=None): - if not callback: - callback = op.callback - try: - # if the callback has a __wrapped__ attribute, that means that the - # pipeline added a wrapper around the callback, so we want to look - # at the original function instead of the wrapped function. - callback = callback.__wrapped__ - except AttributeError: - pass - assert callback.call_count == 1 - callback_op_arg = callback.call_args[0][0] - assert callback_op_arg == op - callback_error_arg = callback.call_args[1]["error"] - assert callback_error_arg is None - - -def assert_callback_failed(op, callback=None, error=None): - if not callback: - callback = op.callback - try: - # if the callback has a __wrapped__ attribute, that means that the - # pipeline added a wrapper around the callback, so we want to look - # at the original function instead of the wrapped function. - callback = callback.__wrapped__ - except AttributeError: - pass - assert callback.call_count == 1 - callback_op_arg = callback.call_args[0][0] - assert callback_op_arg == op - - callback_error_arg = callback.call_args[1]["error"] - if error: - if isinstance(error, type): - assert callback_error_arg.__class__ == error - else: - assert callback_error_arg is error - else: - assert callback_error_arg is not None - - -def get_arg_count(fn): - """ - return the number of arguments (args) passed into a - particular function. Returned value does not include kwargs. - """ - try: - # if __wrapped__ is set, we're looking at a decorated function - # Functools.wraps doesn't copy arg metadata, so we need to - # get argument count from the wrapped function instead. - return len(getargspec(fn.__wrapped__).args) - except AttributeError: - return len(getargspec(fn).args) - - -def make_mock_op_or_event(cls): - args = [None for i in (range(get_arg_count(cls.__init__) - 1))] - return cls(*args) - - -def add_mock_method_waiter(obj, method_name): - """ - For mock methods, add "wait_for_xxx_to_be_called" and "wait_for_xxx_to_not_be_called" - helper functions on the object. This is very handy for methods that get called by - another thread, when you want your test functions to wait until the other thread is - able to call the method without using a sleep call. 
- """ - method_called = Event() - - def signal_method_called(*args, **kwargs): - method_called.set() - - def wait_for_method_to_be_called(): - method_called.wait(0.1) - assert method_called.isSet() - method_called.clear() - - def wait_for_method_to_not_be_called(): - method_called.wait(0.1) - assert not method_called.isSet() - - getattr(obj, method_name).side_effect = signal_method_called - setattr(obj, "wait_for_{}_to_be_called".format(method_name), wait_for_method_to_be_called) - setattr( - obj, "wait_for_{}_to_not_be_called".format(method_name), wait_for_method_to_not_be_called - ) diff --git a/azure-iot-device/tests/common/pipeline/pipeline_config_test.py b/azure-iot-device/tests/common/pipeline/pipeline_config_test.py deleted file mode 100644 index 21fbf4750..000000000 --- a/azure-iot-device/tests/common/pipeline/pipeline_config_test.py +++ /dev/null @@ -1,124 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import pytest -from azure.iot.device import ProxyOptions - - -class PipelineConfigInstantiationTestBase(object): - """All PipelineConfig instantiation tests should inherit from this base class. - It provides tests for shared functionality among all PipelineConfigs, derived from - the BasePipelineConfig class. - """ - - @pytest.mark.it( - "Instantiates with the 'websockets' attribute set to the provided 'websockets' parameter" - ) - @pytest.mark.parametrize( - "websockets", [True, False], ids=["websockets == True", "websockets == False"] - ) - def test_websockets_set(self, config_cls, websockets): - config = config_cls(websockets=websockets) - assert config.websockets is websockets - - @pytest.mark.it( - "Instantiates with the 'websockets' attribute to 'False' if no 'websockets' parameter is provided" - ) - def test_websockets_default(self, config_cls): - config = config_cls() - assert config.websockets is False - - @pytest.mark.it( - "Instantiates with the 'cipher' attribute set to OpenSSL list formatted version of the provided 'cipher' parameter" - ) - @pytest.mark.parametrize( - "cipher_input, expected_cipher", - [ - pytest.param( - "DHE-RSA-AES128-SHA", - "DHE-RSA-AES128-SHA", - id="Single cipher suite, OpenSSL list formatted string", - ), - pytest.param( - "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", - "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", - id="Multiple cipher suites, OpenSSL list formatted string", - ), - pytest.param( - "DHE_RSA_AES128_SHA", - "DHE-RSA-AES128-SHA", - id="Single cipher suite, as string with '_' delimited algorithms/protocols", - ), - pytest.param( - "DHE_RSA_AES128_SHA:DHE_RSA_AES256_SHA:ECDHE_ECDSA_AES128_GCM_SHA256", - "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", - id="Multiple cipher suites, as string with '_' delimited algorithms/protocols and ':' delimited suites", - ), - pytest.param( - ["DHE-RSA-AES128-SHA"], - "DHE-RSA-AES128-SHA", - id="Single cipher suite, in a list, with '-' delimited algorithms/protocols", - ), - pytest.param( - ["DHE-RSA-AES128-SHA", "DHE-RSA-AES256-SHA", "ECDHE-ECDSA-AES128-GCM-SHA256"], - "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", - id="Multiple cipher suites, in a list, with '-' delimited algorithms/protocols", - ), - pytest.param( - 
["DHE_RSA_AES128_SHA"], - "DHE-RSA-AES128-SHA", - id="Single cipher suite, in a list, with '_' delimited algorithms/protocols", - ), - pytest.param( - ["DHE_RSA_AES128_SHA", "DHE_RSA_AES256_SHA", "ECDHE_ECDSA_AES128_GCM_SHA256"], - "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256", - id="Multiple cipher suites, in a list, with '_' delimited algorithms/protocols", - ), - ], - ) - def test_cipher(self, config_cls, cipher_input, expected_cipher): - config = config_cls(cipher=cipher_input) - assert config.cipher == expected_cipher - - @pytest.mark.it( - "Raises TypeError if the provided 'cipher' attribute is neither list nor string" - ) - @pytest.mark.parametrize( - "cipher", - [ - pytest.param(123, id="int"), - pytest.param( - {"cipher1": "DHE-RSA-AES128-SHA", "cipher2": "DHE_RSA_AES256_SHA"}, id="dict" - ), - pytest.param(object(), id="complex object"), - ], - ) - def test_invalid_cipher_param(self, config_cls, cipher): - with pytest.raises(TypeError): - config_cls(cipher=cipher) - - @pytest.mark.it( - "Instantiates with the 'cipher' attribute to empty string ('') if no 'cipher' parameter is provided" - ) - def test_cipher_default(self, config_cls): - config = config_cls() - assert config.cipher == "" - - @pytest.mark.it( - "Instantiates with the 'proxy_options' attribute set to the ProxyOptions object provided in the 'proxy_options' parameter" - ) - def test_proxy_options(self, mocker, config_cls): - proxy_options = ProxyOptions( - proxy_type=mocker.MagicMock(), proxy_addr="127.0.0.1", proxy_port=8888 - ) - config = config_cls(proxy_options=proxy_options) - assert config.proxy_options is proxy_options - - @pytest.mark.it( - "Instantiates with the 'proxy_options' attribute to 'None' if no 'proxy_options' parameter is provided" - ) - def test_proxy_options_default(self, config_cls): - config = config_cls() - assert config.proxy_options is None diff --git a/azure-iot-device/tests/common/pipeline/pipeline_event_test.py b/azure-iot-device/tests/common/pipeline/pipeline_event_test.py index 3c28713a1..3d09cf2c2 100644 --- a/azure-iot-device/tests/common/pipeline/pipeline_event_test.py +++ b/azure-iot-device/tests/common/pipeline/pipeline_event_test.py @@ -8,7 +8,7 @@ fake_count = 0 -# CT-TODO: refactor this module +# CT-TODO: refactor this module in order to be more like pipeline_ops_test.py def get_next_fake_value(): diff --git a/azure-iot-device/tests/common/pipeline/pipeline_stage_test.py b/azure-iot-device/tests/common/pipeline/pipeline_stage_test.py index 7e8fa05ab..8759608d4 100644 --- a/azure-iot-device/tests/common/pipeline/pipeline_stage_test.py +++ b/azure-iot-device/tests/common/pipeline/pipeline_stage_test.py @@ -5,12 +5,7 @@ # -------------------------------------------------------------------------- import logging import pytest -from tests.common.pipeline.helpers import ( - all_except, - make_mock_op_or_event, - StageRunOpTestBase, - StageHandlePipelineEventTestBase, -) +from tests.common.pipeline.helpers import StageRunOpTestBase, StageHandlePipelineEventTestBase from azure.iot.device.common.pipeline.pipeline_stages_base import PipelineStage, PipelineRootStage from azure.iot.device.common.pipeline import pipeline_exceptions from azure.iot.device.common import handle_exceptions @@ -187,112 +182,3 @@ def test_passes_up(self, mocker, stage, event): "Test{}HandlePipelineEventUnhandledEvent".format(stage_class_under_test.__name__), StageHandlePipelineEventUnhandledEvent, ) - - -############################################################# -# CODE BELOW THIS POINT IS 
DEPRECATED PENDING TEST OVERHAUL # -############################################################# - -# CT-TODO: Remove this as soon as possible - - -def add_base_pipeline_stage_tests_old( - cls, - module, - all_ops, - handled_ops, - all_events, - handled_events, - extra_initializer_defaults={}, - positional_arguments=[], - keyword_arguments={}, -): - """ - Add all of the "basic" tests for validating a pipeline stage. This includes tests for - instantiation and tests for properly handling "unhandled" operations and events". - """ - - # NOTE: this infrastructure has been disabled, resulting in a reduction in test coverage. - # Please port all stage tests to the new version of this function above to remedy - # this problem. - - # add_instantiation_test( - # cls=cls, - # module=module, - # defaults={"name": cls.__name__, "next": None, "previous": None, "pipeline_root": None}, - # extra_defaults=extra_initializer_defaults, - # positional_arguments=positional_arguments, - # keyword_arguments=keyword_arguments, - # ) - _add_unknown_ops_tests(cls=cls, module=module, all_ops=all_ops, handled_ops=handled_ops) - _add_unknown_events_tests( - cls=cls, module=module, all_events=all_events, handled_events=handled_events - ) - - -def _add_unknown_ops_tests(cls, module, all_ops, handled_ops): - """ - Add tests for properly handling of "unknown operations," which are operations that aren't - handled by a particular stage. These operations should be passed down by any stage into - the stages that follow. - """ - unknown_ops = all_except(all_items=all_ops, items_to_exclude=handled_ops) - - @pytest.mark.describe("{} - .run_op() -- unknown and unhandled operations".format(cls.__name__)) - class LocalTestObject(StageRunOpTestBase): - @pytest.fixture(params=unknown_ops) - def op(self, request, mocker): - op = make_mock_op_or_event(request.param) - op.callback_stack.append(mocker.MagicMock()) - return op - - @pytest.fixture - def stage(self): - if cls == PipelineRootStage: - return cls(None) - else: - return cls() - - @pytest.mark.it("Passes unknown operation down to the next stage") - def test_passes_op_to_next_stage(self, mocker, op, stage): - mocker.spy(stage, "send_op_down") - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - - setattr(module, "Test{}UnknownOps".format(cls.__name__), LocalTestObject) - - -def _add_unknown_events_tests(cls, module, all_events, handled_events): - """ - Add tests for properly handling of "unknown events," which are events that aren't - handled by a particular stage. These operations should be passed up by any stage into - the stages that proceed it.. 
- """ - - unknown_events = all_except(all_items=all_events, items_to_exclude=handled_events) - - if not unknown_events: - return - - @pytest.mark.describe( - "{} - .handle_pipeline_event() -- unknown and unhandled events".format(cls.__name__) - ) - class LocalTestObject(StageHandlePipelineEventTestBase): - @pytest.fixture(params=unknown_events) - def event(self, request): - return make_mock_op_or_event(request.param) - - @pytest.fixture - def stage(self): - return cls() - - @pytest.mark.it("Passes unknown event to previous stage") - def test_passes_event_to_previous_stage(self, stage, event, mocker): - mocker.spy(stage, "send_event_up") - stage.handle_pipeline_event(event) - - assert stage.send_event_up.call_count == 1 - assert stage.send_event_up.call_args == mocker.call(event) - - setattr(module, "Test{}UnknownEvents".format(cls.__name__), LocalTestObject) diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_ops_base.py b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_base.py index 9e0400063..6d3b2e194 100644 --- a/azure-iot-device/tests/common/pipeline/test_pipeline_ops_base.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_base.py @@ -14,6 +14,24 @@ pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") +class InitializePipelineOperationTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_ops_base.InitializePipelineOperation + + @pytest.fixture + def init_kwargs(self, mocker): + kwargs = {"callback": mocker.MagicMock()} + return kwargs + + +pipeline_ops_test.add_operation_tests( + test_module=this_module, + op_class_under_test=pipeline_ops_base.InitializePipelineOperation, + op_test_config_class=InitializePipelineOperationTestConfig, +) + + class ConnectOperationTestConfig(object): @pytest.fixture def cls_type(self): @@ -140,32 +158,6 @@ def test_feature_name(self, cls_type, init_kwargs): ) -class UpdateSasTokenOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_base.UpdateSasTokenOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"sas_token": "some_token", "callback": mocker.MagicMock()} - return kwargs - - -class UpdateSasTokenOperationInstantiationTests(UpdateSasTokenOperationTestConfig): - @pytest.mark.it("Initializes 'sas_token' attribute with the provided 'sas_token' parameter") - def test_sas_token(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.sas_token == init_kwargs["sas_token"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_base.UpdateSasTokenOperation, - op_test_config_class=UpdateSasTokenOperationTestConfig, - extended_op_instantiation_test_class=UpdateSasTokenOperationInstantiationTests, -) - - class RequestAndResponseOperationTestConfig(object): @pytest.fixture def cls_type(self): diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_ops_http.py b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_http.py index 6a69ab2c0..142ef93d4 100644 --- a/azure-iot-device/tests/common/pipeline/test_pipeline_ops_http.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_http.py @@ -14,79 +14,6 @@ pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -class SetHTTPConnectionArgsOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_http.SetHTTPConnectionArgsOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = { - "hostname": "some_hostname", - "callback": mocker.MagicMock(), - 
"server_verification_cert": "some_server_verification_cert", - "client_cert": "some_client_cert", - "sas_token": "some_sas_token", - } - return kwargs - - -class SetHTTPConnectionArgsOperationInstantiationTests(SetHTTPConnectionArgsOperationTestConfig): - @pytest.mark.it("Initializes 'hostname' attribute with the provided 'hostname' parameter") - def test_hostname(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.hostname == init_kwargs["hostname"] - - @pytest.mark.it( - "Initializes 'server_verification_cert' attribute with the provided 'server_verification_cert' parameter" - ) - def test_server_verification_cert(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.server_verification_cert == init_kwargs["server_verification_cert"] - - @pytest.mark.it( - "Initializes 'server_verification_cert' attribute to None if no 'server_verification_cert' parameter is provided" - ) - def test_server_verification_cert_default(self, cls_type, init_kwargs): - del init_kwargs["server_verification_cert"] - op = cls_type(**init_kwargs) - assert op.server_verification_cert is None - - @pytest.mark.it("Initializes 'client_cert' attribute with the provided 'client_cert' parameter") - def test_client_cert(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.client_cert == init_kwargs["client_cert"] - - @pytest.mark.it( - "Initializes 'client_cert' attribute to None if no 'client_cert' parameter is provided" - ) - def test_client_cert_default(self, cls_type, init_kwargs): - del init_kwargs["client_cert"] - op = cls_type(**init_kwargs) - assert op.client_cert is None - - @pytest.mark.it("Initializes 'sas_token' attribute with the provided 'sas_token' parameter") - def test_sas_token(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.sas_token == init_kwargs["sas_token"] - - @pytest.mark.it( - "Initializes 'sas_token' attribute to None if no 'sas_token' parameter is provided" - ) - def test_sas_token_default(self, cls_type, init_kwargs): - del init_kwargs["sas_token"] - op = cls_type(**init_kwargs) - assert op.sas_token is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_http.SetHTTPConnectionArgsOperation, - op_test_config_class=SetHTTPConnectionArgsOperationTestConfig, - extended_op_instantiation_test_class=SetHTTPConnectionArgsOperationInstantiationTests, -) - - class HTTPRequestAndResponseOperationTestConfig(object): @pytest.fixture def cls_type(self): @@ -151,7 +78,7 @@ def test_reason(self, cls_type, init_kwargs): pipeline_ops_test.add_operation_tests( test_module=this_module, - op_class_under_test=pipeline_ops_http.SetHTTPConnectionArgsOperation, + op_class_under_test=pipeline_ops_http.HTTPRequestAndResponseOperation, op_test_config_class=HTTPRequestAndResponseOperationTestConfig, extended_op_instantiation_test_class=HTTPRequestAndResponseOperationInstantiationTests, ) diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_ops_mqtt.py b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_mqtt.py index 52c6f2e85..673d899d4 100644 --- a/azure-iot-device/tests/common/pipeline/test_pipeline_ops_mqtt.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_ops_mqtt.py @@ -14,91 +14,6 @@ pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -class SetMQTTConnectionArgsOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_mqtt.SetMQTTConnectionArgsOperation - - @pytest.fixture - def init_kwargs(self, 
mocker): - kwargs = { - "client_id": "some_client_id", - "hostname": "some_hostname", - "username": "some_username", - "callback": mocker.MagicMock(), - "server_verification_cert": "some_server_verification_cert", - "client_cert": "some_client_cert", - "sas_token": "some_sas_token", - } - return kwargs - - -class SetMQTTConnectionArgsOperationInstantiationTests(SetMQTTConnectionArgsOperationTestConfig): - @pytest.mark.it("Initializes 'client_id' attribute with the provided 'client_id' parameter") - def test_client_id(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.client_id == init_kwargs["client_id"] - - @pytest.mark.it("Initializes 'hostname' attribute with the provided 'hostname' parameter") - def test_hostname(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.hostname == init_kwargs["hostname"] - - @pytest.mark.it("Initializes 'username' attribute with the provided 'username' parameter") - def test_username(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.username == init_kwargs["username"] - - @pytest.mark.it( - "Initializes 'server_verification_cert' attribute with the provided 'server_verification_cert' parameter" - ) - def test_server_verification_cert(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.server_verification_cert == init_kwargs["server_verification_cert"] - - @pytest.mark.it( - "Initializes 'server_verification_cert' attribute to None if no 'server_verification_cert' parameter is provided" - ) - def test_server_verification_cert_default(self, cls_type, init_kwargs): - del init_kwargs["server_verification_cert"] - op = cls_type(**init_kwargs) - assert op.server_verification_cert is None - - @pytest.mark.it("Initializes 'client_cert' attribute with the provided 'client_cert' parameter") - def test_client_cert(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.client_cert == init_kwargs["client_cert"] - - @pytest.mark.it( - "Initializes 'client_cert' attribute to None if no 'client_cert' parameter is provided" - ) - def test_client_cert_default(self, cls_type, init_kwargs): - del init_kwargs["client_cert"] - op = cls_type(**init_kwargs) - assert op.client_cert is None - - @pytest.mark.it("Initializes 'sas_token' attribute with the provided 'sas_token' parameter") - def test_sas_token(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.sas_token == init_kwargs["sas_token"] - - @pytest.mark.it( - "Initializes 'sas_token' attribute to None if no 'sas_token' parameter is provided" - ) - def test_sas_token_default(self, cls_type, init_kwargs): - del init_kwargs["sas_token"] - op = cls_type(**init_kwargs) - assert op.sas_token is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_mqtt.SetMQTTConnectionArgsOperation, - op_test_config_class=SetMQTTConnectionArgsOperationTestConfig, - extended_op_instantiation_test_class=SetMQTTConnectionArgsOperationInstantiationTests, -) - - class MQTTPublishOperationTestConfig(object): @pytest.fixture def cls_type(self): diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_stages_base.py b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_base.py index 3f2dff8e3..6ac723901 100644 --- a/azure-iot-device/tests/common/pipeline/test_pipeline_stages_base.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_base.py @@ -219,6 +219,361 @@ def test_invoke_handler(self, mocker, stage, event): assert mock_handler.call_args 
== mocker.call(event)
+
+###########################
+# SAS TOKEN RENEWAL STAGE #
+###########################
+
+
+class SasTokenRenewalStageTestConfig(object):
+    @pytest.fixture
+    def cls_type(self):
+        return pipeline_stages_base.SasTokenRenewalStage
+
+    @pytest.fixture
+    def init_kwargs(self, mocker):
+        return {}
+
+    @pytest.fixture
+    def stage(self, mocker, cls_type, init_kwargs):
+        stage = cls_type(**init_kwargs)
+        stage.pipeline_root = pipeline_stages_base.PipelineRootStage(
+            pipeline_configuration=mocker.MagicMock()
+        )
+        # Add mock SasToken
+        mock_sastoken = mocker.MagicMock()
+        mock_sastoken.ttl = 10000
+        stage.pipeline_root.pipeline_configuration.sastoken = mock_sastoken
+        # Mock flow methods
+        stage.send_op_down = mocker.MagicMock()
+        stage.send_event_up = mocker.MagicMock()
+        return stage
+
+
+class SasTokenRenewalStageInstantiationTests(SasTokenRenewalStageTestConfig):
+    @pytest.mark.it("Initializes with the token renewal timer set to 'None'")
+    def test_token_renewal_timer(self, init_kwargs):
+        stage = pipeline_stages_base.SasTokenRenewalStage(**init_kwargs)
+        assert stage._token_renewal_timer is None
+
+    @pytest.mark.it("Uses 120 seconds as the Renewal Margin by default")
+    def test_renewal_margin(self, init_kwargs):
+        # NOTE: currently, the renewal margin isn't really set as an instance attribute; it just
+        # uses a constant defined on the class in all cases. Eventually this logic may be expanded
+        # to be more dynamic, and this test will need to change accordingly.
+        stage = pipeline_stages_base.SasTokenRenewalStage(**init_kwargs)
+        assert stage.DEFAULT_TOKEN_RENEWAL_MARGIN == 120
+
+
+pipeline_stage_test.add_base_pipeline_stage_tests(
+    test_module=this_module,
+    stage_class_under_test=pipeline_stages_base.SasTokenRenewalStage,
+    stage_test_config_class=SasTokenRenewalStageTestConfig,
+    extended_stage_instantiation_test_class=SasTokenRenewalStageInstantiationTests,
+)
+
+
+@pytest.mark.describe(
+    "SasTokenRenewalStage - .run_op() -- Called with InitializePipelineOperation, on a pipeline configured with SAS authentication"
+)
+class TestSasTokenRenewalStageRunOpWithInitializePipelineOpSasTokenConfig(
+    SasTokenRenewalStageTestConfig, StageRunOpTestBase
+):
+    @pytest.fixture
+    def op(self, mocker):
+        return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock())
+
+    @pytest.mark.it("Cancels any existing token renewal timer that may have been set")
+    def test_cancels_existing_timer(self, mocker, stage, op):
+        mock_timer = mocker.MagicMock()
+        stage._token_renewal_timer = mock_timer
+
+        stage.run_op(op)
+
+        assert mock_timer.cancel.call_count == 1
+        assert mock_timer.cancel.call_args == mocker.call()
+
+    @pytest.mark.it("Resets the token renewal timer to None until a new one is set")
+    # Edge case: unless something goes wrong, the timer WILL be set again afterwards, making it
+    # appear as if it was never set to None.
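+    # For context, a rough sketch (illustrative only, NOT the actual implementation) of the
+    # ordering these tests assume inside the stage's handling of InitializePipelineOperation:
+    #
+    #     if self._token_renewal_timer:
+    #         self._token_renewal_timer.cancel()          # cancel any previous timer
+    #     self._token_renewal_timer = None                # <- the intermediate state tested below
+    #     seconds = sastoken.ttl - self.DEFAULT_TOKEN_RENEWAL_MARGIN
+    #     new_timer = threading.Timer(seconds, renew)     # if this raises, the attribute stays None
+    #     new_timer.daemon = True
+    #     new_timer.start()
+    #     self._token_renewal_timer = new_timer
+    #
+    # The names above ('sastoken', 'renew') are assumptions inferred from the assertions in this
+    # class; only the observable behavior (cancel, clear, then start a new daemon timer that fires
+    # ttl - DEFAULT_TOKEN_RENEWAL_MARGIN seconds later) is what the tests verify.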
+    def test_timer_set_to_none_in_intermediate(
+        self, mocker, stage, op, mock_timer, arbitrary_exception
+    ):
+        # Set an existing timer
+        stage._token_renewal_timer = mocker.MagicMock()
+
+        # Set an error side effect on the timer creation, so when a new timer is created,
+        # we have an unhandled error causing op failure and early exit
+        mock_timer.side_effect = arbitrary_exception
+
+        stage.run_op(op)
+
+        assert op.completed
+        assert op.error is arbitrary_exception
+        assert stage._token_renewal_timer is None
+
+    @pytest.mark.it(
+        "Starts a background renewal timer that expires 'Renewal Margin' seconds prior to SasToken expiration"
+    )
+    def test_sets_timer(self, mocker, stage, op, mock_timer):
+        expected_timer_seconds = (
+            stage.pipeline_root.pipeline_configuration.sastoken.ttl
+            - pipeline_stages_base.SasTokenRenewalStage.DEFAULT_TOKEN_RENEWAL_MARGIN
+        )
+
+        stage.run_op(op)
+
+        assert mock_timer.call_count == 1
+        assert mock_timer.call_args[0][0] == expected_timer_seconds
+        assert mock_timer.return_value.daemon is True
+        assert mock_timer.return_value.start.call_count == 1
+        assert mock_timer.return_value.start.call_args == mocker.call()
+
+    @pytest.mark.it(
+        "Sends a PipelineError to the background exception handler and does not set a timer if the SasToken TTL is less than the Renewal Margin (the time prior to token expiration at which renewal is triggered)"
+    )
+    def test_token_ttl_less_than_renewal_timer(self, mocker, stage, op, mock_timer):
+        # NOTE: this really shouldn't happen in a regular flow. It is an extreme edge case that
+        # is likely only possible if a bug exists elsewhere in the stack
+        stage.pipeline_root.pipeline_configuration.sastoken.ttl = (
+            pipeline_stages_base.SasTokenRenewalStage.DEFAULT_TOKEN_RENEWAL_MARGIN - 1
+        )
+        mocker.spy(handle_exceptions, "handle_background_exception")
+
+        stage.run_op(op)
+
+        assert handle_exceptions.handle_background_exception.call_count == 1
+        assert isinstance(
+            handle_exceptions.handle_background_exception.call_args[0][0],
+            pipeline_exceptions.PipelineError,
+        )
+        assert mock_timer.call_count == 0
+
+
+@pytest.mark.describe(
+    "SasTokenRenewalStage - .run_op() -- Called with InitializePipelineOperation, on a pipeline NOT configured with SAS authentication"
+)
+class TestSasTokenRenewalStageRunOpWithInitializePipelineOpNoSasTokenConfig(
+    SasTokenRenewalStageTestConfig, StageRunOpTestBase
+):
+    @pytest.fixture
+    def op(self, mocker):
+        return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock())
+
+    # Override inherited fixture so that there is NO sastoken
+    @pytest.fixture
+    def stage(self, mocker, cls_type, init_kwargs):
+        stage = cls_type(**init_kwargs)
+        stage.pipeline_root = pipeline_stages_base.PipelineRootStage(
+            pipeline_configuration=mocker.MagicMock()
+        )
+        # No SasToken
+        stage.pipeline_root.pipeline_configuration.sastoken = None
+        # Mock flow methods
+        stage.send_op_down = mocker.MagicMock()
+        stage.send_event_up = mocker.MagicMock()
+        return stage
+
+    @pytest.mark.it("Sends the operation down, WITHOUT setting a renewal timer")
+    def test_sends_op_down_no_timer(self, mocker, stage, op):
+        mock_timer = mocker.patch.object(threading, "Timer")
+
+        stage.run_op(op)
+
+        assert stage.send_op_down.call_count == 1
+        assert stage.send_op_down.call_args == mocker.call(op)
+        assert stage._token_renewal_timer is None
+        assert mock_timer.call_count == 0
+
+
+@pytest.mark.describe("SasTokenRenewalStage - OCCURANCE: SasToken Renewal Timer expires")
+class
TestSasTokenRenewalStageOCCURANCETimerExpires(SasTokenRenewalStageTestConfig): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) + + @pytest.mark.it("Refreshes the pipeline's SasToken") + @pytest.mark.parametrize( + "connected", + [ + pytest.param(True, id="Pipeline connected"), + pytest.param(False, id="Pipeline not connected"), + ], + ) + def test_refresh_token(self, stage, op, mock_timer, connected): + # Apply the timer + stage.run_op(op) + + # Set connected state + stage.pipeline_root.connected = connected + + # Token has not been refreshed + token = stage.pipeline_root.pipeline_configuration.sastoken + assert token.refresh.call_count == 0 + assert mock_timer.call_count == 1 + + # Call timer complete callback (as if timer expired) + on_timer_complete = mock_timer.call_args[0][1] + on_timer_complete() + + # Token has now been refreshed + assert token.refresh.call_count == 1 + + @pytest.mark.it( + "Sends a ReauthorizeConnectionOperation down the pipeline if the pipeline is in a 'connected' state" + ) + def test_when_pipeline_connected(self, mocker, stage, op, mock_timer): + # Apply the timer and set stage as connected + stage.pipeline_root.connected = True + stage.run_op(op) + + # Only the InitializePipeline op has been sent down + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + # Pipeline is still connected + assert stage.pipeline_root.connected is True + + # Call timer complete callback (as if timer expired) + assert mock_timer.call_count == 1 + on_timer_complete = mock_timer.call_args[0][1] + on_timer_complete() + + # ReauthorizeConnectionOperation has now been sent down + assert stage.send_op_down.call_count == 2 + assert isinstance( + stage.send_op_down.call_args[0][0], pipeline_ops_base.ReauthorizeConnectionOperation + ) + + @pytest.mark.it( + "Does NOT send a ReauthorizeConnectionOperation down the pipeline if the pipeline is NOT in a 'connected' state" + ) + def test_when_pipeline_not_connected(self, mocker, stage, op, mock_timer): + # Apply the timer and set stage as connected + stage.pipeline_root.connected = False + stage.run_op(op) + + # Only the InitializePipeline op has been sent down + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + # Pipeline is still NOT connected + assert stage.pipeline_root.connected is False + + # Call timer complete callback (as if timer expired) + on_timer_complete = mock_timer.call_args[0][1] + on_timer_complete() + + # No further ops have been sent down + assert stage.send_op_down.call_count == 1 + + @pytest.mark.it( + "If the ReauthorizeConnectionOperation is later completed with an error, send the error to the background exception handler" + ) + def test_reauth_op_error_goes_to_bkg_handler( + self, mocker, stage, op, mock_timer, arbitrary_exception + ): + mocker.spy(handle_exceptions, "handle_background_exception") + + # Apply the timer and set stage as connected + stage.pipeline_root.connected = True + stage.run_op(op) + + # Call timer complete callback (as if timer expired) + assert mock_timer.call_count == 1 + on_timer_complete = mock_timer.call_args[0][1] + on_timer_complete() + + # ReauthorizeConnectionOperation has now been sent down + assert stage.send_op_down.call_count == 2 + reauth_op = stage.send_op_down.call_args[0][0] + assert isinstance(reauth_op, pipeline_ops_base.ReauthorizeConnectionOperation) + + # Complete ReauthorizeConnectionOperation with 
error + reauth_op.complete(error=arbitrary_exception) + + # Error was sent to background handler + assert handle_exceptions.handle_background_exception.call_count == 1 + assert handle_exceptions.handle_background_exception.call_args == mocker.call( + arbitrary_exception + ) + + @pytest.mark.it("Begins a new SasToken renewal timer") + @pytest.mark.parametrize( + "connected", + [ + pytest.param(True, id="Pipeline connected"), + pytest.param(False, id="Pipeline not connected"), + ], + ) + # I am sorry for this test length, but IDK how else to test this... + # ... other than throwing everything at it at once + def test_new_timer(self, mocker, stage, op, mock_timer, connected): + token = stage.pipeline_root.pipeline_configuration.sastoken + + # Set connected state + stage.pipeline_root.connected = connected + + # Apply the timer + stage.run_op(op) + + # op was passed down + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + # Only one timer has been created and started. No cancellation. + assert mock_timer.call_count == 1 + assert mock_timer.return_value.start.call_count == 1 + assert mock_timer.return_value.cancel.call_count == 0 + + # Call timer complete callback (as if timer expired) + on_timer_complete = mock_timer.call_args[0][1] + on_timer_complete() + + # Existing timer was cancelled + assert mock_timer.return_value.cancel.call_count == 1 + + # Token was refreshed + assert token.refresh.call_count == 1 + + # Reauthorize was sent down (if the connection state was right) + if connected: + assert stage.send_op_down.call_count == 2 + assert isinstance( + stage.send_op_down.call_args[0][0], pipeline_ops_base.ReauthorizeConnectionOperation + ) + else: + assert stage.send_op_down.call_count == 1 + + # Another timer was created and started for the expected time + assert mock_timer.call_count == 2 + expected_timer_seconds = ( + stage.pipeline_root.pipeline_configuration.sastoken.ttl + - pipeline_stages_base.SasTokenRenewalStage.DEFAULT_TOKEN_RENEWAL_MARGIN + ) + assert mock_timer.call_args[0][0] == expected_timer_seconds + assert stage._token_renewal_timer is mock_timer.return_value + assert stage._token_renewal_timer.daemon is True + assert stage._token_renewal_timer.start.call_count == 2 + assert stage._token_renewal_timer.start.call_args == mocker.call() + + # When THAT timer expires, the token is refreshed, and the reauth is sent, etc. etc. etc. + # ... recursion :) + new_on_timer_complete = mock_timer.call_args[0][1] + new_on_timer_complete() + + assert token.refresh.call_count == 2 + if connected: + assert stage.send_op_down.call_count == 3 + assert isinstance( + stage.send_op_down.call_args[0][0], pipeline_ops_base.ReauthorizeConnectionOperation + ) + else: + assert stage.send_op_down.call_count == 1 + + assert mock_timer.call_count == 3 + # .... 
and on and on for infinity + + ###################### # AUTO CONNECT STAGE # ###################### diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_stages_http.py b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_http.py index ccd00a6a6..a85ff4d35 100644 --- a/azure-iot-device/tests/common/pipeline/test_pipeline_stages_http.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_http.py @@ -61,16 +61,12 @@ def stage(self, mocker, cls_type, init_kwargs): stage.pipeline_root = pipeline_stages_base.PipelineRootStage( pipeline_configuration=mocker.MagicMock() ) + stage.pipeline_root.hostname = "some.fake-host.name.com" stage.send_op_down = mocker.MagicMock() return stage class HTTPTransportInstantiationTests(HTTPTransportStageTestConfig): - @pytest.mark.it("Initializes 'sas_token' attribute as None") - def test_sas_token(self, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - assert stage.sas_token is None - @pytest.mark.it("Initializes 'transport' attribute as None") def test_transport(self, cls_type, init_kwargs): stage = cls_type(**init_kwargs) @@ -85,26 +81,14 @@ def test_transport(self, cls_type, init_kwargs): ) -@pytest.mark.describe( - "HTTPTransportStage - .run_op() -- Called with SetHTTPConnectionArgsOperation" -) -class TestHTTPTransportStageRunOpCalledWithSetHTTPConnectionArgsOperation( +@pytest.mark.describe("HTTPTransportStage - .run_op() -- Called with InitializePipelineOperation") +class TestHTTPTransportStageRunOpCalledWithInitializePipelineOperation( HTTPTransportStageTestConfig, StageRunOpTestBase ): @pytest.fixture def op(self, mocker): - return pipeline_ops_http.SetHTTPConnectionArgsOperation( - hostname="fake_hostname", - server_verification_cert="fake_server_verification_cert", - client_cert="fake_client_cert", - sas_token="fake_sas_token", - callback=mocker.MagicMock(), - ) - - @pytest.mark.it("Stores the sas_token operation in the 'sas_token' attribute of the stage") - def test_stores_data(self, stage, op, mocker, mock_transport): - stage.run_op(op) - assert stage.sas_token == op.sas_token + op = pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) + return op @pytest.mark.it( "Creates an HTTPTransport object and sets it as the 'transport' attribute of the stage (and on the pipeline root)" @@ -120,9 +104,23 @@ def test_stores_data(self, stage, op, mocker, mock_transport): pytest.param("", id="Pipeline NOT configured for custom cipher(s)"), ], ) - def test_creates_transport(self, mocker, stage, op, mock_transport, cipher): + @pytest.mark.parametrize( + "gateway_hostname", + [ + pytest.param("fake.gateway.hostname.com", id="Using Gateway Hostname"), + pytest.param(None, id="Not using Gateway Hostname"), + ], + ) + def test_creates_transport(self, mocker, stage, op, mock_transport, cipher, gateway_hostname): # Setup pipeline config stage.pipeline_root.pipeline_configuration.cipher = cipher + stage.pipeline_root.pipeline_configuration.gateway_hostname = gateway_hostname + + # NOTE: if more of this type of logic crops up, consider splitting this test up + if stage.pipeline_root.pipeline_configuration.gateway_hostname: + expected_hostname = stage.pipeline_root.pipeline_configuration.gateway_hostname + else: + expected_hostname = stage.pipeline_root.pipeline_configuration.hostname assert stage.transport is None @@ -130,9 +128,9 @@ def test_creates_transport(self, mocker, stage, op, mock_transport, cipher): assert mock_transport.call_count == 1 assert mock_transport.call_args == mocker.call( - 
hostname=op.hostname, - server_verification_cert=op.server_verification_cert, - x509_cert=op.client_cert, + hostname=expected_hostname, + server_verification_cert=stage.pipeline_root.pipeline_configuration.server_verification_cert, + x509_cert=stage.pipeline_root.pipeline_configuration.x509, cipher=cipher, ) assert stage.transport is mock_transport.return_value @@ -145,81 +143,33 @@ def test_succeeds(self, mocker, stage, op, mock_transport): # NOTE: The HTTPTransport object is not instantiated upon instantiation of the HTTPTransportStage. -# It is only added once the SetHTTPConnectionArgsOperation runs. +# It is only added once the InitializePipelineOperation runs. # The lifecycle of the HTTPTransportStage is as follows: # 1. Instantiate the stage -# 2. Configure the stage with a SetHTTPConnectionArgsOperation +# 2. Configure the stage with an InitializePipelineOperation # 3. Run any other desired operations. # -# This is to say, no operation should be running before SetHTTPConnectionArgsOperation. +# This is to say, no operation should be running before InitializePipelineOperation. # Thus, for the following tests, we will assume that the HTTPTransport has already been created, # and as such, the stage fixture used will have already have one. class HTTPTransportStageTestConfigComplex(HTTPTransportStageTestConfig): - # We add a pytest fixture parametrization between SAS an X509 since depending on the version of authentication, the op will be formatted differently. - @pytest.fixture(params=["SAS", "X509"]) - def stage(self, mocker, request, cls_type, init_kwargs): - mock_transport = mocker.patch( - "azure.iot.device.common.pipeline.pipeline_stages_http.HTTPTransport", autospec=True - ) + @pytest.fixture + def stage(self, mocker, request, cls_type, init_kwargs, mock_transport): stage = cls_type(**init_kwargs) stage.pipeline_root = pipeline_stages_base.PipelineRootStage( pipeline_configuration=mocker.MagicMock() ) stage.send_op_down = mocker.MagicMock() + # Set up the Transport on the stage - if request.param == "SAS": - op = pipeline_ops_http.SetHTTPConnectionArgsOperation( - hostname="fake_hostname", - server_verification_cert="fake_server_verification_cert", - sas_token="fake_sas_token", - callback=mocker.MagicMock(), - ) - else: - op = pipeline_ops_http.SetHTTPConnectionArgsOperation( - hostname="fake_hostname", - server_verification_cert="fake_server_verification_cert", - client_cert="fake_client_cert", - callback=mocker.MagicMock(), - ) + op = pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) stage.run_op(op) + assert stage.transport is mock_transport.return_value return stage -@pytest.mark.describe("HTTPTransportStage - .run_op() -- Called with UpdateSasTokenOperation") -class TestHTTPTransportStageRunOpCalledWithUpdateSasTokenOperation( - HTTPTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.UpdateSasTokenOperation( - sas_token="new_fake_sas_token", callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Updates the 'sas_token' attribute to be the new value contained in the operation" - ) - def test_updates_token(self, stage, op): - assert stage.sas_token != op.sas_token - stage.run_op(op) - assert stage.sas_token == op.sas_token - - @pytest.mark.it("Completes the operation with success, upon successful execution") - def test_completes_op(self, stage, op): - assert not op.completed - stage.run_op(op) - assert op.completed - - -fake_method = "__fake_method__" -fake_path = "__fake_path__" 
-fake_headers = {"__fake_key__": "__fake_value__"} -fake_body = "__fake_body__" -fake_query_params = "__fake_query_params__" -fake_sas_token = "fake_sas_token" - - @pytest.mark.describe( "HTTPTransportStage - .run_op() -- Called with HTTPRequestAndResponseOperation" ) @@ -229,45 +179,58 @@ class TestHTTPTransportStageRunOpCalledWithHTTPRequestAndResponseOperation( @pytest.fixture def op(self, mocker): return pipeline_ops_http.HTTPRequestAndResponseOperation( - method=fake_method, - path=fake_path, - headers=fake_headers, - body=fake_body, - query_params=fake_query_params, + method="SOME_METHOD", + path="fake/path", + headers={"fake_key": "fake_val"}, + body="fake_body", + query_params="arg1=val1;arg2=val2", callback=mocker.MagicMock(), ) @pytest.mark.it("Sends an HTTP request via the HTTPTransport") def test_http_request(self, mocker, stage, op): stage.run_op(op) - # We add this because the default stage here contains a SAS Token. - fake_headers["Authorization"] = fake_sas_token + assert stage.transport.request.call_count == 1 assert stage.transport.request.call_args == mocker.call( - method=fake_method, - path=fake_path, - headers=fake_headers, - body=fake_body, - query_params=fake_query_params, + method=op.method, + path=op.path, + # headers are tested in depth in the following two tests + headers=mocker.ANY, + body=op.body, + query_params=op.query_params, callback=mocker.ANY, ) @pytest.mark.it( - "Does not provide an Authorization header if the SAS Token is not set in the stage" + "Adds the SasToken in the request's 'Authorization' header if using SAS-based authentication" ) - def test_header_with_no_sas(self, mocker, stage, op): - # Manually overwriting stage with no SAS Token. - stage.sas_token = None + def test_headers_with_sas_auth(self, mocker, stage, op): + # A SasToken is set on the pipeline, but Authorization headers have not yet been set + assert stage.pipeline_root.pipeline_configuration.sastoken is not None + assert op.headers.get("Authorization") is None + stage.run_op(op) - assert stage.transport.request.call_count == 1 - assert stage.transport.request.call_args == mocker.call( - method=fake_method, - path=fake_path, - headers=fake_headers, - body=fake_body, - query_params=fake_query_params, - callback=mocker.ANY, - ) + + # Need to get the headers sent to the transport, not provided by the op, due to a + # deep copy that occurs + headers = stage.transport.request.call_args[1]["headers"] + assert headers["Authorization"] == str(stage.pipeline_root.pipeline_configuration.sastoken) + + @pytest.mark.it( + "Does NOT add the 'Authorization' header to the request if NOT using SAS-based authentication" + ) + def test_headers_with_no_sas(self, mocker, stage, op): + # NO SasToken is set on the pipeline, and Authorization headers have not yet been set + stage.pipeline_root.pipeline_configuration.sastoken = None + assert op.headers.get("Authorization") is None + + stage.run_op(op) + + # Need to get the headers sent to the transport, not provided by the op, due to a + # deep copy that occurs + headers = stage.transport.request.call_args[1]["headers"] + assert headers.get("Authorization") is None @pytest.mark.it( "Completes the operation unsucessfully if there is a failure requesting via the HTTPTransport, using the error raised by the HTTPTransport" diff --git a/azure-iot-device/tests/common/pipeline/test_pipeline_stages_mqtt.py b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_mqtt.py index 60f9d62bd..fd31017c1 100644 --- 
a/azure-iot-device/tests/common/pipeline/test_pipeline_stages_mqtt.py +++ b/azure-iot-device/tests/common/pipeline/test_pipeline_stages_mqtt.py @@ -70,17 +70,13 @@ def stage(self, mocker, cls_type, init_kwargs): stage.pipeline_root = pipeline_stages_base.PipelineRootStage( pipeline_configuration=mocker.MagicMock() ) + stage.pipeline_root.hostname = "some.fake-host.name.com" stage.send_op_down = mocker.MagicMock() stage.send_event_up = mocker.MagicMock() return stage class MQTTTransportInstantiationTests(MQTTTransportStageTestConfig): - @pytest.mark.it("Initializes 'sas_token' attribute as None") - def test_sas_token(self, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - assert stage.sas_token is None - @pytest.mark.it("Initializes 'transport' attribute as None") def test_transport(self, cls_type, init_kwargs): stage = cls_type(**init_kwargs) @@ -100,28 +96,17 @@ def test_pending_op(self, cls_type, init_kwargs): ) -@pytest.mark.describe( - "MQTTTransportStage - .run_op() -- Called with SetMQTTConnectionArgsOperation" -) -class TestMQTTTransportStageRunOpCalledWithSetMQTTConnectionArgsOperation( +@pytest.mark.describe("MQTTTransportStage - .run_op() -- Called with InitializePipelineOperation") +class TestMQTTTransportStageRunOpCalledWithInitializePipelineOperation( MQTTTransportStageTestConfig, StageRunOpTestBase ): @pytest.fixture def op(self, mocker): - return pipeline_ops_mqtt.SetMQTTConnectionArgsOperation( - client_id="fake_client_id", - hostname="fake_hostname", - username="fake_username", - server_verification_cert="fake_server_verification_cert", - client_cert="fake_client_cert", - sas_token="fake_sas_token", - callback=mocker.MagicMock(), - ) - - @pytest.mark.it("Stores the sas_token operation in the 'sas_token' attribute of the stage") - def test_stores_data(self, stage, op, mocker, mock_transport): - stage.run_op(op) - assert stage.sas_token == op.sas_token + op = pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) + # These values are patched onto the op in a previous stage + op.client_id = "fake_client_id" + op.username = "fake_username" + return op @pytest.mark.it( "Creates an MQTTTransport object and sets it as the 'transport' attribute of the stage" @@ -152,13 +137,27 @@ def test_stores_data(self, stage, op, mocker, mock_transport): pytest.param("", id="Proxy Absent"), ], ) + @pytest.mark.parametrize( + "gateway_hostname", + [ + pytest.param("fake.gateway.hostname.com", id="Using Gateway Hostname"), + pytest.param(None, id="Not using Gateway Hostname"), + ], + ) def test_creates_transport( - self, mocker, stage, op, mock_transport, websockets, cipher, proxy_options + self, mocker, stage, op, mock_transport, websockets, cipher, proxy_options, gateway_hostname ): # Configure websockets & cipher stage.pipeline_root.pipeline_configuration.websockets = websockets stage.pipeline_root.pipeline_configuration.cipher = cipher stage.pipeline_root.pipeline_configuration.proxy_options = proxy_options + stage.pipeline_root.pipeline_configuration.gateway_hostname = gateway_hostname + + # NOTE: if more of this type of logic crops up, consider splitting this test up + if stage.pipeline_root.pipeline_configuration.gateway_hostname: + expected_hostname = stage.pipeline_root.pipeline_configuration.gateway_hostname + else: + expected_hostname = stage.pipeline_root.pipeline_configuration.hostname assert stage.transport is None @@ -167,10 +166,10 @@ def test_creates_transport( assert mock_transport.call_count == 1 assert mock_transport.call_args == mocker.call( 
client_id=op.client_id, - hostname=op.hostname, + hostname=expected_hostname, username=op.username, - server_verification_cert=op.server_verification_cert, - x509_cert=op.client_cert, + server_verification_cert=stage.pipeline_root.pipeline_configuration.server_verification_cert, + x509_cert=stage.pipeline_root.pipeline_configuration.x509, websockets=websockets, cipher=cipher, proxy_options=proxy_options, @@ -188,9 +187,11 @@ def test_sets_transport_handlers(self, mocker, stage, op, mock_transport): ) assert stage.transport.on_mqtt_message_received_handler == stage._on_mqtt_message_received - # CT-TODO: does this even need to be happening in this stage? Shouldn't this be part of init? @pytest.mark.it("Sets the stage's pending connection operation to None") - def test_pending_conn_op(self, stage, op, mock_transport): + def test_pending_conn_op(self, mocker, stage, op, mock_transport): + # NOTE: The pending connection operation ALREADY should be None, but we set it to None + # again for safety here just in case. So this test is for an edge case. + stage._pending_connection_op = mocker.MagicMock() stage.run_op(op) assert stage._pending_connection_op is None @@ -202,13 +203,13 @@ def test_succeeds(self, mocker, stage, op, mock_transport): # NOTE: The MQTTTransport object is not instantiated upon instantiation of the MQTTTransportStage. -# It is only added once the SetMQTTConnectionArgsOperation runs. +# It is only added once the InitializePipelineOperation runs. # The lifecycle of the MQTTTransportStage is as follows: # 1. Instantiate the stage -# 2. Configure the stage with a SetMQTTConnectionArgsOperation +# 2. Configure the stage with an InitializePipelineOperation # 3. Run any other desired operations. # -# This is to say, no operation should be running before SetMQTTConnectionArgsOperation. +# This is to say, no operation should be running before InitializePipelineOperation. # Thus, for the following tests, we will assume that the MQTTTransport has already been created, # and as such, the stage fixture used will have already have one. 
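+# For orientation, a minimal sketch of that lifecycle (illustrative only -- the concrete values
+# and the manual patching of 'client_id'/'username' onto the op are assumptions that mirror the
+# fixtures below; in a real pipeline an earlier stage supplies them):
+#
+#     stage = pipeline_stages_mqtt.MQTTTransportStage()
+#     stage.pipeline_root = pipeline_stages_base.PipelineRootStage(pipeline_configuration=cfg)
+#     init_op = pipeline_ops_base.InitializePipelineOperation(callback=cb)
+#     init_op.client_id = "fake_client_id"
+#     init_op.username = "fake_username"
+#     stage.run_op(init_op)                                           # MQTTTransport created here
+#     stage.run_op(pipeline_ops_base.ConnectOperation(callback=cb))   # later ops are now safe
+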
class MQTTTransportStageTestConfigComplex(MQTTTransportStageTestConfig): @@ -222,46 +223,16 @@ def stage(self, mocker, cls_type, init_kwargs, mock_transport): stage.send_event_up = mocker.MagicMock() # Set up the Transport on the stage - op = pipeline_ops_mqtt.SetMQTTConnectionArgsOperation( - client_id="fake_client_id", - hostname="fake_hostname", - username="fake_username", - server_verification_cert="fake_server_verification_cert", - client_cert="fake_client_cert", - sas_token="fake_sas_token", - callback=mocker.MagicMock(), - ) + op = pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) + op.client_id = "fake_client_id" + op.username = "fake_username" stage.run_op(op) + assert stage.transport is mock_transport.return_value return stage -@pytest.mark.describe("MQTTTransportStage - .run_op() -- Called with UpdateSasTokenOperation") -class TestMQTTTransportStageRunOpCalledWithUpdateSasTokenOperation( - MQTTTransportStageTestConfigComplex, StageRunOpTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.UpdateSasTokenOperation( - sas_token="new_fake_sas_token", callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Updates the 'sas_token' attribute to be the new value contained in the operation" - ) - def test_updates_token(self, stage, op): - assert stage.sas_token != op.sas_token - stage.run_op(op) - assert stage.sas_token == op.sas_token - - @pytest.mark.it("Completes the operation with success, upon successful execution") - def test_complets_op(self, stage, op): - assert not op.completed - stage.run_op(op) - assert op.completed - - @pytest.mark.describe("MQTTTransportStage - .run_op() -- Called with ConnectOperation") class TestMQTTTransportStageRunOpCalledWithConnectOperation( MQTTTransportStageTestConfigComplex, StageRunOpTestBase @@ -317,11 +288,26 @@ def test_starts_watchdog(self, mocker, stage, op, mock_timer): assert mock_timer.return_value.daemon is True assert mock_timer.return_value.start.call_count == 1 - @pytest.mark.it("Performs an MQTT connect via the MQTTTransport") - def test_mqtt_connect(self, mocker, stage, op): + @pytest.mark.it( + "Performs an MQTT connect via the MQTTTransport, using the root's SasToken as a password, if using SAS-based authentication" + ) + def test_mqtt_connect_sastoken(self, mocker, stage, op): + assert stage.pipeline_root.pipeline_configuration.sastoken is not None + stage.run_op(op) + assert stage.transport.connect.call_count == 1 + assert stage.transport.connect.call_args == mocker.call( + password=str(stage.pipeline_root.pipeline_configuration.sastoken) + ) + + @pytest.mark.it( + "Performs an MQTT connect via the MQTTTransport, with no password, if NOT using SAS-based authentication" + ) + def test_mqtt_connect_no_sastoken(self, mocker, stage, op): + # no token + stage.pipeline_root.pipeline_configuration.sastoken = None stage.run_op(op) assert stage.transport.connect.call_count == 1 - assert stage.transport.connect.call_args == mocker.call(password=stage.sas_token) + assert stage.transport.connect.call_args == mocker.call(password=None) @pytest.mark.it( "Completes the operation unsucessfully if there is a failure connecting via the MQTTTransport, using the error raised by the MQTTTransport" @@ -415,12 +401,14 @@ def test_starts_watchdog(self, mocker, stage, op, mock_timer): assert mock_timer.return_value.daemon is True assert mock_timer.return_value.start.call_count == 1 - @pytest.mark.it("Performs an MQTT reconnect via the MQTTTransport") + @pytest.mark.it( + "Performs an MQTT reconnect via 
the MQTTTransport, using the pipeline root's SasToken as a password" + ) def test_mqtt_connect(self, mocker, stage, op): stage.run_op(op) assert stage.transport.reauthorize_connection.call_count == 1 assert stage.transport.reauthorize_connection.call_args == mocker.call( - password=stage.sas_token + password=str(stage.pipeline_root.pipeline_configuration.sastoken) ) @pytest.mark.it( diff --git a/azure-iot-device/tests/common/test_sastoken.py b/azure-iot-device/tests/common/test_sastoken.py deleted file mode 100644 index 479d929e4..000000000 --- a/azure-iot-device/tests/common/test_sastoken.py +++ /dev/null @@ -1,134 +0,0 @@ -# -*- coding: utf-8 -*- -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import time -import base64 -import hmac -import hashlib -import copy -import logging -import six.moves.urllib as urllib -from azure.iot.device.common.sastoken import SasToken, SasTokenError - -logging.basicConfig(level=logging.DEBUG) - -uri = "my.host.name" -key = "Zm9vYmFy" -key_name = "mykeyname" -device_token_kwargs = {"uri": uri, "key": key} -service_token_kwargs = {"uri": uri, "key": key, "key_name": key_name} - - -def generate_signature(uri, key, expiry_time): - message = (uri + "\n" + str(expiry_time)).encode("utf-8") - signing_key = base64.b64decode(key.encode("utf-8")) - signed_hmac = hmac.HMAC(signing_key, message, hashlib.sha256) - signature = urllib.parse.quote(base64.b64encode(signed_hmac.digest())) - return signature - - -@pytest.fixture(params=["Device Token", "Service Token"]) -def sastoken(request): - token_type = request.param - if token_type == "Device Token": - return SasToken(uri, key) - elif token_type == "Service Token": - return SasToken(uri, key, key_name) - - -@pytest.mark.describe("SasToken") -class TestSasToken(object): - @pytest.mark.it("Instantiates with default TTL of 3600 seconds") - @pytest.mark.parametrize( - "kwargs", - [ - pytest.param(device_token_kwargs, id="Device Token"), - pytest.param(service_token_kwargs, id="Service Token"), - ], - ) - def test_instantiates_with_default_ttl_3600(self, kwargs): - s = SasToken(**kwargs) - assert s._uri == kwargs.get("uri") - assert s._key == kwargs.get("key") - assert s._key_name == kwargs.get("key_name") - assert s.ttl == 3600 - - @pytest.mark.it("Instantiates with a custom TTL") - @pytest.mark.parametrize( - "kwargs", - [ - pytest.param(device_token_kwargs, id="Device Token"), - pytest.param(service_token_kwargs, id="Service Token"), - ], - ) - def test_instantiates_with_custom_ttl(self, kwargs): - kwargs = copy.copy(kwargs) - kwargs["ttl"] = 9000 - s = SasToken(**kwargs) - assert s._uri == kwargs.get("uri") - assert s._key == kwargs.get("key") - assert s._key_name == kwargs.get("key_name") - assert s.ttl == 9000 - - @pytest.mark.it("URL encodes UTF-8 characters in provided URI") - def test_url_encodes_utf8_characters_in_uri(self): - utf8_uri = "my châteu.host.name" - s = SasToken(utf8_uri, key) - - expected_uri = "my+ch%C3%A2teu.host.name" - assert s._uri == expected_uri - - @pytest.mark.it("Raises SasTokenError if provided a key that is not base64 encoded") - def test_raises_sastoken_error_if_key_is_not_base64(self): - non_b64_key = "this is not base64" - with pytest.raises(SasTokenError): - SasToken(uri, non_b64_key) - - 
@pytest.mark.it("Converting object to string returns the SasToken string") - @pytest.mark.parametrize( - "sastoken,token_pattern", - [ - pytest.param( - "Device Token", "SharedAccessSignature sr={}&sig={}&se={}", id="Device Token" - ), - pytest.param( - "Service Token", - "SharedAccessSignature sr={}&sig={}&se={}&skn={}", - id="Service Token", - ), - ], - indirect=["sastoken"], - ) - def test_string_conversion_returns_expected_sastoken_string(self, sastoken, token_pattern): - signature = generate_signature(sastoken._uri, sastoken._key, sastoken.expiry_time) - if sastoken._key_name: - expected_string = token_pattern.format( - sastoken._uri, signature, sastoken.expiry_time, sastoken._key_name - ) - else: - expected_string = token_pattern.format(sastoken._uri, signature, sastoken.expiry_time) - strrep = str(sastoken) - assert strrep == expected_string - - @pytest.mark.it("Can be refreshed to extend the expiry time by the TTL") - def test_refreshing_token_sets_expiry_time_to_be_ttl_seconds_in_the_future( - self, mocker, sastoken - ): - current_time = 1000 - mocker.patch.object(time, "time", return_value=current_time) - sastoken.refresh() - assert sastoken.expiry_time == current_time + sastoken.ttl - - @pytest.mark.it("Updates SasToken string upon refresh") - def test_refreshing_token_changes_string_representation(self, sastoken): - # This should happen because refreshing updates expiry time - old_token_string = str(sastoken) - time.sleep(1) - sastoken.refresh() - new_token_string = str(sastoken) - assert old_token_string != new_token_string diff --git a/azure-iot-device/tests/iothub/auth/__init__.py b/azure-iot-device/tests/iothub/aio/__init__.py similarity index 100% rename from azure-iot-device/tests/iothub/auth/__init__.py rename to azure-iot-device/tests/iothub/aio/__init__.py diff --git a/azure-iot-device/tests/iothub/aio/test_async_clients.py b/azure-iot-device/tests/iothub/aio/test_async_clients.py index 8e140a89b..057bfefe8 100644 --- a/azure-iot-device/tests/iothub/aio/test_async_clients.py +++ b/azure-iot-device/tests/iothub/aio/test_async_clients.py @@ -11,16 +11,25 @@ import time import os import io +import sys from azure.iot.device import exceptions as client_exceptions from azure.iot.device.iothub.aio import IoTHubDeviceClient, IoTHubModuleClient -from azure.iot.device.iothub.pipeline import MQTTPipeline, constant +from azure.iot.device.iothub.pipeline import constant from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions from azure.iot.device.iothub.models import Message, MethodRequest from azure.iot.device.iothub.aio.async_inbox import AsyncClientInbox from azure.iot.device.common import async_adapter -from azure.iot.device.iothub.auth import IoTEdgeError -import sys from azure.iot.device import constant as device_constant +from ..shared_client_tests import ( + SharedIoTHubClientInstantiationTests, + SharedIoTHubClientPROPERTYConnectedTests, + SharedIoTHubClientCreateFromConnectionStringTests, + SharedIoTHubDeviceClientCreateFromSymmetricKeyTests, + SharedIoTHubDeviceClientCreateFromX509CertificateTests, + SharedIoTHubModuleClientCreateFromX509CertificateTests, + SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnvTests, + SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnvTests, +) pytestmark = pytest.mark.asyncio logging.basicConfig(level=logging.DEBUG) @@ -32,278 +41,9 @@ async def create_completed_future(result=None): return f -# automatically mock the mqtt pipeline for all tests in this file. 
-@pytest.fixture(autouse=True) -def mock_mqtt_pipeline_init(mocker): - return mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline") - - -# automatically mock the http pipeline for all tests in this file. -@pytest.fixture(autouse=True) -def mock_http_pipeline_init(mocker): - return mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") - - -class SharedClientInstantiationTests(object): - @pytest.mark.it( - "Stores the MQTTPipeline from the 'mqtt_pipeline' parameter in the '_mqtt_pipeline' attribute" - ) - async def test_mqtt_pipeline_attribute(self, client_class, mqtt_pipeline, http_pipeline): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline is mqtt_pipeline - - @pytest.mark.it( - "Stores the HTTPPipeline from the 'http_pipeline' parameter in the '_http_pipeline' attribute" - ) - async def test_sets_http_pipeline_attribute(self, client_class, mqtt_pipeline, http_pipeline): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._http_pipeline is http_pipeline - - @pytest.mark.it("Sets on_connected handler in the MQTTPipeline") - async def test_sets_on_connected_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_connected is not None - assert client._mqtt_pipeline.on_connected == client._on_connected - - @pytest.mark.it("Sets on_disconnected handler in the MQTTPipeline") - async def test_sets_on_disconnected_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_disconnected is not None - assert client._mqtt_pipeline.on_disconnected == client._on_disconnected - - @pytest.mark.it("Sets on_method_request_received handler in the MQTTPipeline") - async def test_sets_on_method_request_received_handler_in_pipleline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_method_request_received is not None - assert ( - client._mqtt_pipeline.on_method_request_received - == client._inbox_manager.route_method_request - ) - - -class SharedClientCreateMethodUserOptionTests(object): - # In these tests we patch the entire 'auth' library instead of specific auth providers in order - # to make them more generic, and applicable across all creation methods. 
- - @pytest.fixture - def option_test_required_patching(self, mocker): - """Override this fixture in a subclass if unique patching is required""" - pass - - @pytest.mark.it( - "Sets the 'product_info' user option parameter on the PipelineConfig, if provided" - ) - async def test_product_info_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - product_info = "MyProductInfo" - client_create_method(*create_method_args, product_info=product_info) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][1] - assert config == mock_http_pipeline_init.call_args[0][1] - - assert config.product_info == product_info - - @pytest.mark.it( - "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" - ) - async def test_websockets_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - client_create_method(*create_method_args, websockets=True) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][1] - assert config == mock_http_pipeline_init.call_args[0][1] - - assert config.websockets - - @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") - async def test_cipher_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" - client_create_method(*create_method_args, cipher=cipher) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][1] - assert config == mock_http_pipeline_init.call_args[0][1] - - assert config.cipher == cipher - - @pytest.mark.it( - "Sets the 'server_verification_cert' user option parameter on the AuthenticationProvider, if provided" - ) - async def test_server_verification_cert_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - server_verification_cert = "fake_server_verification_cert" - client_create_method(*create_method_args, server_verification_cert=server_verification_cert) - - # Get auth provider object, and ensure it was used for both protocol pipelines - auth = mock_mqtt_pipeline_init.call_args[0][0] - assert auth == mock_http_pipeline_init.call_args[0][0] - - assert auth.server_verification_cert == server_verification_cert - - @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") - async def test_invalid_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - with pytest.raises(TypeError): - client_create_method(*create_method_args, invalid_option="some_value") - - @pytest.mark.it("Sets default user options if none are provided") - async def test_default_options( - self, - mocker, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - mock_config = 
mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - - client_create_method(*create_method_args) - - # Pipeline Config was instantiated with default arguments - assert mock_config.call_count == 1 - expected_kwargs = {} - assert mock_config.call_args == mocker.call(**expected_kwargs) - - # This default config was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][1] == mock_config.return_value - assert mock_http_pipeline_init.call_args[0][1] == mock_config.return_value - - # Get auth provider object, and ensure it was used for both protocol pipelines - auth = mock_mqtt_pipeline_init.call_args[0][0] - assert auth == mock_http_pipeline_init.call_args[0][0] - - # Ensure that auth options are set to expected defaults - assert auth.server_verification_cert is None - - -class SharedClientCreateFromConnectionStringTests(object): - @pytest.mark.it("Uses the connection string to create a SymmetricKeyAuthenticationProvider") - async def test_auth_provider_creation(self, mocker, client_class, connection_string): - mock_auth_parse = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse - mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - - client_class.create_from_connection_string(connection_string) - - assert mock_auth_parse.call_count == 1 - assert mock_auth_parse.call_args == mocker.call(connection_string) - - @pytest.mark.it("Uses the SymmetricKeyAuthenticationProvider to create an MQTTPipeline") - @pytest.mark.parametrize( - "server_verification_cert", - [ - pytest.param(None, id="No Server Verification Certificate"), - pytest.param("some-certificate", id="With Server Verification Certificate"), - ], - ) - async def test_pipeline_creation( - self, - mocker, - client_class, - connection_string, - server_verification_cert, - mock_mqtt_pipeline_init, - ): - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse.return_value - - mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - - client_class.create_from_connection_string(connection_string) - - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args == mocker.call( - mock_auth, mock_config_init.return_value - ) - - @pytest.mark.it("Uses the MQTTPipeline to instantiate the client") - async def test_client_instantiation(self, mocker, client_class, connection_string): - mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline").return_value - mock_pipeline_http = mocker.patch( - "azure.iot.device.iothub.pipeline.HTTPPipeline" - ).return_value - spy_init = mocker.spy(client_class, "__init__") - - client_class.create_from_connection_string(connection_string) - - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) - - @pytest.mark.it("Returns the instantiated client") - async def test_returns_client(self, client_class, connection_string): - client = client_class.create_from_connection_string(connection_string) - - assert isinstance(client, client_class) - - # TODO: If auth package was refactored to use ConnectionString class, tests from that - # class would increase the coverage here. 
- @pytest.mark.it("Raises ValueError when given an invalid connection string") - @pytest.mark.parametrize( - "bad_cs", - [ - pytest.param("not-a-connection-string", id="Garbage string"), - pytest.param(object(), id="Non-string input"), - pytest.param( - "HostName=Invalid;DeviceId=Invalid;SharedAccessKey=Invalid", - id="Malformed Connection String", - marks=pytest.mark.xfail(reason="Bug in pipeline + need for auth refactor"), # TODO - ), - ], - ) - async def test_raises_value_error_on_bad_connection_string(self, client_class, bad_cs): - with pytest.raises(ValueError): - client_class.create_from_connection_string(bad_cs) +####################### +# SHARED CLIENT TESTS # +####################### class SharedClientConnectTests(object): @@ -949,20 +689,6 @@ async def test_returns_message_from_twin_patch_inbox(self, mocker, client, twin_ assert received_patch is twin_patch_desired -class SharedClientPROPERTYConnectedTests(object): - @pytest.mark.it("Cannot be changed") - async def test_read_only(self, client): - with pytest.raises(AttributeError): - client.connected = not client.connected - - @pytest.mark.it("Reflects the value of the root stage property of the same name") - async def test_reflects_pipeline_property(self, client, mqtt_pipeline): - mqtt_pipeline.connected = True - assert client.connected - mqtt_pipeline.connected = False - assert not client.connected - - ################ # DEVICE TESTS # ################ @@ -992,7 +718,7 @@ def sas_token_string(self, device_sas_token_string): @pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - Instantiation") class TestIoTHubDeviceClientInstantiation( - IoTHubDeviceClientTestsConfig, SharedClientInstantiationTests + IoTHubDeviceClientTestsConfig, SharedIoTHubClientInstantiationTests ): @pytest.mark.it("Sets on_c2d_message_received handler in the MQTTPipeline") async def test_sets_on_c2d_message_received_handler_in_pipeline( @@ -1008,105 +734,23 @@ async def test_sets_on_c2d_message_received_handler_in_pipeline( @pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .create_from_connection_string()") class TestIoTHubDeviceClientCreateFromConnectionString( - IoTHubDeviceClientTestsConfig, - SharedClientCreateFromConnectionStringTests, - SharedClientCreateMethodUserOptionTests, + IoTHubDeviceClientTestsConfig, SharedIoTHubClientCreateFromConnectionStringTests ): - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_connection_string - - @pytest.fixture - def create_method_args(self, connection_string): - """Provides the specific create method args for use in universal tests""" - return [connection_string] + pass @pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .create_from_symmetric_key()") class TestConfigurationCreateIoTHubDeviceClientFromSymmetricKey( - IoTHubDeviceClientTestsConfig, SharedClientCreateMethodUserOptionTests + IoTHubDeviceClientTestsConfig, SharedIoTHubDeviceClientCreateFromSymmetricKeyTests ): - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_symmetric_key - - @pytest.fixture - def create_method_args(self, symmetric_key, hostname_fixture, device_id_fixture): - """Provides the specific create method args for use in universal tests""" - return [symmetric_key, hostname_fixture, device_id_fixture] + pass @pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - 
.create_from_x509_certificate()") class TestIoTHubDeviceClientCreateFromX509Certificate( - IoTHubDeviceClientTestsConfig, SharedClientCreateMethodUserOptionTests + IoTHubDeviceClientTestsConfig, SharedIoTHubDeviceClientCreateFromX509CertificateTests ): - hostname = "durmstranginstitute.farend" - device_id = "MySnitch" - - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_x509_certificate - - @pytest.fixture - def create_method_args(self, x509): - """Provides the specific create method args for use in universal tests""" - return [x509, self.hostname, self.device_id] - - @pytest.mark.it("Uses the provided arguments to create a X509AuthenticationProvider") - async def test_auth_provider_creation(self, mocker, client_class, x509): - mock_auth_init = mocker.patch("azure.iot.device.iothub.auth.X509AuthenticationProvider") - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - - assert mock_auth_init.call_count == 1 - assert mock_auth_init.call_args == mocker.call( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - - @pytest.mark.it("Uses the X509AuthenticationProvider to create an MQTTPipeline") - async def test_pipeline_creation(self, mocker, client_class, x509, mock_mqtt_pipeline_init): - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.X509AuthenticationProvider" - ).return_value - - mock_config = mocker.patch( - "azure.iot.device.iothub.pipeline.IoTHubPipelineConfig" - ).return_value - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args == mocker.call(mock_auth, mock_config) - - @pytest.mark.it("Uses the MQTTPipeline to instantiate the client") - async def test_client_instantiation(self, mocker, client_class, x509): - mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline").return_value - mock_pipeline_http = mocker.patch( - "azure.iot.device.iothub.pipeline.HTTPPipeline" - ).return_value - spy_init = mocker.spy(client_class, "__init__") - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) - - @pytest.mark.it("Returns the instantiated client") - async def test_returns_client(self, mocker, client_class, x509): - client = client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - - assert isinstance(client, client_class) + pass @pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - .connect()") @@ -1353,7 +997,7 @@ def fail_notify_blob_upload_status( @pytest.mark.describe("IoTHubDeviceClient (Asynchronous) - PROPERTY .connected") class TestIoTHubDeviceClientPROPERTYConnected( - IoTHubDeviceClientTestsConfig, SharedClientPROPERTYConnectedTests + IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYConnectedTests ): pass @@ -1387,7 +1031,7 @@ def sas_token_string(self, module_sas_token_string): @pytest.mark.describe("IoTHubModuleClient (Asynchronous) - Instantiation") class TestIoTHubModuleClientInstantiation( - IoTHubModuleClientTestsConfig, SharedClientInstantiationTests + IoTHubModuleClientTestsConfig, SharedIoTHubClientInstantiationTests ): @pytest.mark.it("Sets 
on_input_message_received handler in the MQTTPipeline") async def test_sets_on_input_message_received_handler_in_pipeline( @@ -1404,511 +1048,36 @@ async def test_sets_on_input_message_received_handler_in_pipeline( @pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .create_from_connection_string()") class TestIoTHubModuleClientCreateFromConnectionString( - IoTHubModuleClientTestsConfig, - SharedClientCreateFromConnectionStringTests, - SharedClientCreateMethodUserOptionTests, -): - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_connection_string - - @pytest.fixture - def create_method_args(self, connection_string): - """Provides the specific create method args for use in universal tests""" - return [connection_string] - - -class IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests( - SharedClientCreateMethodUserOptionTests + IoTHubModuleClientTestsConfig, SharedIoTHubClientCreateFromConnectionStringTests ): - """This class inherites the user option tests shared by all create method APIs, and overrides - tests in order to accomodate unique requirements for the .create_from_edge_enviornment() method. - - Because .create_from_edge_environment() tests are spread accross multiple test units - (i.e. test classes), these overrides are done in this class, which is then inherited by all - .create_from_edge_environment() test units below. - """ - - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_edge_environment - - @pytest.fixture - def create_method_args(self): - """Provides the specific create method args for use in universal tests""" - return [] - - @pytest.mark.it( - "Raises a TypeError if the 'server_verification_cert' user option parameter is provided" - ) - async def test_server_verification_cert_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - """THIS TEST OVERRIDES AN INHERITED TEST""" - - with pytest.raises(TypeError): - client_create_method( - *create_method_args, server_verification_cert="fake_server_verification_cert" - ) - - @pytest.mark.it("Sets default user options if none are provided") - async def test_default_options( - self, - mocker, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - """THIS TEST OVERRIDES AN INHERITED TEST""" - mock_config = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - - client_create_method(*create_method_args) - - # Pipeline Config was instantiated with default arguments - assert mock_config.call_count == 1 - expected_kwargs = {} - assert mock_config.call_args == mocker.call(**expected_kwargs) - - # This default config was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][1] == mock_config.return_value - assert mock_http_pipeline_init.call_args[0][1] == mock_config.return_value - - # Get auth provider object, and ensure it was used for both protocol pipelines - auth = mock_mqtt_pipeline_init.call_args[0][0] - assert auth == mock_http_pipeline_init.call_args[0][0] + pass @pytest.mark.describe( "IoTHubModuleClient (Asynchronous) - .create_from_edge_environment() -- Edge Container Environment" ) class 
TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( - IoTHubModuleClientTestsConfig, IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests + IoTHubModuleClientTestsConfig, + SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnvTests, ): - @pytest.fixture - def option_test_required_patching(self, mocker, edge_container_environment): - """THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" - mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - - @pytest.mark.it( - "Uses Edge container environment variables to create an IoTEdgeAuthenticationProvider" - ) - async def test_auth_provider_creation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - mock_auth_init = mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - - client_class.create_from_edge_environment() - - assert mock_auth_init.call_count == 1 - assert mock_auth_init.call_args == mocker.call( - hostname=edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"], - device_id=edge_container_environment["IOTEDGE_DEVICEID"], - module_id=edge_container_environment["IOTEDGE_MODULEID"], - gateway_hostname=edge_container_environment["IOTEDGE_GATEWAYHOSTNAME"], - module_generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], - workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], - api_version=edge_container_environment["IOTEDGE_APIVERSION"], - ) - - @pytest.mark.it( - "Ignores any Edge local debug environment variables that may be present, in favor of using Edge container variables" - ) - async def test_auth_provider_creation_hybrid_env( - self, mocker, client_class, edge_container_environment, edge_local_debug_environment - ): - # This test verifies that with a hybrid environment, the auth provider will always be - # an IoTEdgeAuthenticationProvider, even if local debug variables are present - hybrid_environment = {**edge_container_environment, **edge_local_debug_environment} - mocker.patch.dict(os.environ, hybrid_environment, clear=True) - mock_edge_auth_init = mocker.patch( - "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" - ) - mock_sk_auth_parse = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse - - client_class.create_from_edge_environment() - - assert mock_edge_auth_init.call_count == 1 - assert mock_sk_auth_parse.call_count == 0 # we did NOT use SK auth - assert mock_edge_auth_init.call_args == mocker.call( - hostname=edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"], - device_id=edge_container_environment["IOTEDGE_DEVICEID"], - module_id=edge_container_environment["IOTEDGE_MODULEID"], - gateway_hostname=edge_container_environment["IOTEDGE_GATEWAYHOSTNAME"], - module_generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], - workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], - api_version=edge_container_environment["IOTEDGE_APIVERSION"], - ) - - @pytest.mark.it( - "Uses the IoTEdgeAuthenticationProvider to create an MQTTPipeline and an HTTPPipeline" - ) - async def test_pipeline_creation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" - ).return_value - mock_config = mocker.patch( - "azure.iot.device.iothub.pipeline.IoTHubPipelineConfig" 
- ).return_value - - mock_mqtt_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline") - mock_http_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") - - client_class.create_from_edge_environment() - - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args == mocker.call(mock_auth, mock_config) - assert mock_http_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_args == mocker.call(mock_auth, mock_config) - - @pytest.mark.it("Uses the MQTTPipeline and the HTTPPipeline to instantiate the client") - async def test_client_instantiation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - # Always patch the IoTEdgeAuthenticationProvider to prevent I/O operations - mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - mock_mqtt_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.MQTTPipeline" - ).return_value - mock_http_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.HTTPPipeline" - ).return_value - spy_init = mocker.spy(client_class, "__init__") - - client_class.create_from_edge_environment() - - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_mqtt_pipeline, mock_http_pipeline) - - @pytest.mark.it("Returns the instantiated client") - async def test_returns_client(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - # Always patch the IoTEdgeAuthenticationProvider to prevent I/O operations - mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - - client = client_class.create_from_edge_environment() - - assert isinstance(client, client_class) - - @pytest.mark.it("Raises OSError if the environment is missing required variables") - @pytest.mark.parametrize( - "missing_env_var", - [ - "IOTEDGE_MODULEID", - "IOTEDGE_DEVICEID", - "IOTEDGE_IOTHUBHOSTNAME", - "IOTEDGE_GATEWAYHOSTNAME", - "IOTEDGE_APIVERSION", - "IOTEDGE_MODULEGENERATIONID", - "IOTEDGE_WORKLOADURI", - ], - ) - async def test_bad_environment( - self, mocker, client_class, edge_container_environment, missing_env_var - ): - # Remove a variable from the fixture - del edge_container_environment[missing_env_var] - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - - with pytest.raises(OSError): - client_class.create_from_edge_environment() - - @pytest.mark.it("Raises OSError if there is an error using the Edge for authentication") - async def test_bad_edge_auth(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - mock_auth = mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - error = IoTEdgeError() - mock_auth.side_effect = error - with pytest.raises(OSError) as e_info: - client_class.create_from_edge_environment() - assert e_info.value.__cause__ is error + pass @pytest.mark.describe( "IoTHubModuleClient (Asynchronous) - .create_from_edge_environment() -- Edge Local Debug Environment" ) class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv( - IoTHubModuleClientTestsConfig, IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests + IoTHubModuleClientTestsConfig, + SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnvTests, ): - @pytest.fixture - def option_test_required_patching(self, mocker, edge_local_debug_environment): - 
"""THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" - mocker.patch("azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider") - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - mocker.patch.object(io, "open") - - @pytest.fixture - def mock_open(self, mocker): - return mocker.patch.object(io, "open") - - @pytest.mark.it( - "Extracts the server verification certificate from the file indicated by the EdgeModuleCACertificateFile environment variable" - ) - async def test_read_server_verification_cert( - self, mocker, client_class, edge_local_debug_environment, mock_open - ): - mock_file_handle = mock_open.return_value.__enter__.return_value - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - client_class.create_from_edge_environment() - assert mock_open.call_count == 1 - assert mock_open.call_args == mocker.call( - edge_local_debug_environment["EdgeModuleCACertificateFile"], mode="r" - ) - assert mock_file_handle.read.call_count == 1 - - @pytest.mark.it( - "Uses Edge local debug environment variables to create a SymmetricKeyAuthenticationProvider (with server verification cert)" - ) - async def test_auth_provider_creation( - self, mocker, client_class, edge_local_debug_environment, mock_open - ): - expected_cert = mock_open.return_value.__enter__.return_value.read.return_value - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - mock_auth_parse = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse - - client_class.create_from_edge_environment() - - assert mock_auth_parse.call_count == 1 - assert mock_auth_parse.call_args == mocker.call( - edge_local_debug_environment["EdgeHubConnectionString"] - ) - assert mock_auth_parse.return_value.server_verification_cert == expected_cert - - @pytest.mark.it( - "Only uses Edge local debug variables if no Edge container variables are present in the environment" - ) - async def test_auth_provider_and_pipeline_hybrid_env( - self, - mocker, - client_class, - edge_container_environment, - edge_local_debug_environment, - mock_open, - ): - # This test verifies that with a hybrid environment, the auth provider will always be - # an IoTEdgeAuthenticationProvider, even if local debug variables are present - hybrid_environment = {**edge_container_environment, **edge_local_debug_environment} - mocker.patch.dict(os.environ, hybrid_environment, clear=True) - mock_edge_auth_init = mocker.patch( - "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" - ) - mock_sk_auth_parse = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse - - client_class.create_from_edge_environment() - - assert mock_edge_auth_init.call_count == 1 - assert mock_sk_auth_parse.call_count == 0 # we did NOT use SK auth - assert mock_edge_auth_init.call_args == mocker.call( - hostname=edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"], - device_id=edge_container_environment["IOTEDGE_DEVICEID"], - module_id=edge_container_environment["IOTEDGE_MODULEID"], - gateway_hostname=edge_container_environment["IOTEDGE_GATEWAYHOSTNAME"], - module_generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], - workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], - api_version=edge_container_environment["IOTEDGE_APIVERSION"], - ) - - @pytest.mark.it( - "Uses the SymmetricKeyAuthenticationProvider to create an MQTTPipeline and an HTTPPipeline" - ) - async def test_pipeline_creation( - self, mocker, client_class, 
edge_local_debug_environment, mock_open - ): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse.return_value - mock_config = mocker.patch( - "azure.iot.device.iothub.pipeline.IoTHubPipelineConfig" - ).return_value - mock_mqtt_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline") - mock_http_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") - - client_class.create_from_edge_environment() - - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args == mocker.call(mock_auth, mock_config) - assert mock_http_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_args == mocker.call(mock_auth, mock_config) - - @pytest.mark.it("Uses the MQTTPipeline and the HTTPPipeline to instantiate the client") - async def test_client_instantiation( - self, mocker, client_class, edge_local_debug_environment, mock_open - ): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - mock_mqtt_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.MQTTPipeline" - ).return_value - mock_http_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.HTTPPipeline" - ).return_value - spy_init = mocker.spy(client_class, "__init__") - - client_class.create_from_edge_environment() - - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_mqtt_pipeline, mock_http_pipeline) - - @pytest.mark.it("Returns the instantiated client") - async def test_returns_client( - self, mocker, client_class, edge_local_debug_environment, mock_open - ): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - - client = client_class.create_from_edge_environment() - - assert isinstance(client, client_class) - - @pytest.mark.it("Raises OSError if the environment is missing required variables") - @pytest.mark.parametrize( - "missing_env_var", ["EdgeHubConnectionString", "EdgeModuleCACertificateFile"] - ) - async def test_bad_environment( - self, mocker, client_class, edge_local_debug_environment, missing_env_var, mock_open - ): - # Remove a variable from the fixture - del edge_local_debug_environment[missing_env_var] - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - - with pytest.raises(OSError): - client_class.create_from_edge_environment() - - # TODO: If auth package was refactored to use ConnectionString class, tests from that - # class would increase the coverage here. 
- @pytest.mark.it( - "Raises ValueError if the connection string in the EdgeHubConnectionString environment varialbe is invalid" - ) - @pytest.mark.parametrize( - "bad_cs", - [ - pytest.param("not-a-connection-string", id="Garbage string"), - pytest.param("", id="Empty string"), - pytest.param( - "HostName=Invalid;DeviceId=Invalid;ModuleId=Invalid;SharedAccessKey=Invalid;GatewayHostName=Invalid", - id="Malformed Connection String", - marks=pytest.mark.xfail(reason="Bug in pipeline + need for auth refactor"), # TODO - ), - ], - ) - async def test_bad_connection_string( - self, mocker, client_class, edge_local_debug_environment, bad_cs, mock_open - ): - edge_local_debug_environment["EdgeHubConnectionString"] = bad_cs - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - - with pytest.raises(ValueError): - client_class.create_from_edge_environment() - - @pytest.mark.it( - "Raises ValueError if the filepath in the EdgeModuleCACertificateFile environment variable is invalid" - ) - async def test_bad_filepath( - self, mocker, client_class, edge_local_debug_environment, mock_open - ): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - error = FileNotFoundError() - mock_open.side_effect = error - with pytest.raises(ValueError) as e_info: - client_class.create_from_edge_environment() - assert e_info.value.__cause__ is error - - @pytest.mark.it( - "Raises ValueError if the file referenced by the filepath in the EdgeModuleCACertificateFile environment variable cannot be opened" - ) - async def test_bad_file_io(self, mocker, client_class, edge_local_debug_environment, mock_open): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - error = OSError() - mock_open.side_effect = error - with pytest.raises(ValueError) as e_info: - client_class.create_from_edge_environment() - assert e_info.value.__cause__ is error + pass @pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .create_from_x509_certificate()") class TestIoTHubModuleClientCreateFromX509Certificate( - IoTHubModuleClientTestsConfig, SharedClientCreateMethodUserOptionTests + IoTHubModuleClientTestsConfig, SharedIoTHubModuleClientCreateFromX509CertificateTests ): - hostname = "durmstranginstitute.farend" - device_id = "MySnitch" - module_id = "Charms" - - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_x509_certificate - - @pytest.fixture - def create_method_args(self, x509): - """Provides the specific create method args for use in universal tests""" - return [x509, self.hostname, self.device_id, self.module_id] - - @pytest.mark.it("Uses the provided arguments to create a X509AuthenticationProvider") - async def test_auth_provider_creation(self, mocker, client_class, x509): - mock_auth_init = mocker.patch("azure.iot.device.iothub.auth.X509AuthenticationProvider") - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - - assert mock_auth_init.call_count == 1 - assert mock_auth_init.call_args == mocker.call( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - - @pytest.mark.it("Uses the X509AuthenticationProvider to create an MQTTPipeline") - async def test_pipeline_creation(self, mocker, client_class, x509, mock_mqtt_pipeline_init): - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.X509AuthenticationProvider" - 
).return_value - - mock_config = mocker.patch( - "azure.iot.device.iothub.pipeline.IoTHubPipelineConfig" - ).return_value - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args == mocker.call(mock_auth, mock_config) - - @pytest.mark.it("Uses the MQTTPipeline to instantiate the client") - async def test_client_instantiation(self, mocker, client_class, x509): - mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline").return_value - mock_pipeline_http = mocker.patch( - "azure.iot.device.iothub.pipeline.HTTPPipeline" - ).return_value - spy_init = mocker.spy(client_class, "__init__") - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) - - @pytest.mark.it("Returns the instantiated client") - async def test_returns_client(self, mocker, client_class, x509): - client = client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - - assert isinstance(client, client_class) + pass @pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .connect()") @@ -1937,16 +1106,16 @@ class TestIoTHubNModuleClientSendD2CMessage( @pytest.mark.describe("IoTHubModuleClient (Asynchronous) - .send_message_to_output()") class TestIoTHubModuleClientSendToOutput(IoTHubModuleClientTestsConfig): - @pytest.mark.it("Begins a 'send_output_event' pipeline operation") + @pytest.mark.it("Begins a 'send_output_message' pipeline operation") async def test_calls_pipeline_send_message_to_output(self, client, mqtt_pipeline, message): output_name = "some_output" await client.send_message_to_output(message, output_name) - assert mqtt_pipeline.send_output_event.call_count == 1 - assert mqtt_pipeline.send_output_event.call_args[0][0] is message + assert mqtt_pipeline.send_output_message.call_count == 1 + assert mqtt_pipeline.send_output_message.call_args[0][0] is message assert message.output_name == output_name @pytest.mark.it( - "Waits for the completion of the 'send_output_event' pipeline operation before returning" + "Waits for the completion of the 'send_output_message' pipeline operation before returning" ) async def test_waits_for_pipeline_op_completion(self, mocker, client, mqtt_pipeline, message): cb_mock = mocker.patch.object(async_adapter, "AwaitableCallback").return_value @@ -1956,12 +1125,12 @@ async def test_waits_for_pipeline_op_completion(self, mocker, client, mqtt_pipel await client.send_message_to_output(message, output_name) # Assert callback is sent to pipeline - assert mqtt_pipeline.send_output_event.call_args[1]["callback"] is cb_mock + assert mqtt_pipeline.send_output_message.call_args[1]["callback"] is cb_mock # Assert callback completion is waited upon assert cb_mock.completion.call_count == 1 @pytest.mark.it( - "Raises a client error if the `send_output_event` pipeline operation calls back with a pipeline error" + "Raises a client error if the `send_output_message` pipeline operation calls back with a pipeline error" ) @pytest.mark.parametrize( "pipeline_error,client_error", @@ -1994,15 +1163,15 @@ async def test_raises_error_on_pipeline_op_error( ): my_pipeline_error = pipeline_error() - def fail_send_output_event(message, 
callback): + def fail_send_output_message(message, callback): callback(error=my_pipeline_error) - mqtt_pipeline.send_output_event = mocker.MagicMock(side_effect=fail_send_output_event) + mqtt_pipeline.send_output_message = mocker.MagicMock(side_effect=fail_send_output_message) with pytest.raises(client_error) as e_info: output_name = "some_output" await client.send_message_to_output(message, output_name) assert e_info.value.__cause__ is my_pipeline_error - assert mqtt_pipeline.send_output_event.call_count == 1 + assert mqtt_pipeline.send_output_message.call_count == 1 @pytest.mark.it( "Wraps 'message' input parameter in Message object if it is not a Message object" @@ -2023,8 +1192,8 @@ async def test_send_message_to_output_calls_pipeline_wraps_data_in_message( ): output_name = "some_output" await client.send_message_to_output(message_input, output_name) - assert mqtt_pipeline.send_output_event.call_count == 1 - sent_message = mqtt_pipeline.send_output_event.call_args[0][0] + assert mqtt_pipeline.send_output_message.call_count == 1 + sent_message = mqtt_pipeline.send_output_message.call_args[0][0] assert isinstance(sent_message, Message) assert sent_message.data == message_input @@ -2038,7 +1207,7 @@ async def test_raises_error_when_message_to_output_data_greater_than_256( with pytest.raises(ValueError) as e_info: await client.send_message_to_output(message, output_name) assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_output_event.call_count == 0 + assert mqtt_pipeline.send_output_message.call_count == 0 @pytest.mark.it("Raises error when message size is greater than 256 KB") async def test_raises_error_when_message_to_output_size_greater_than_256( @@ -2051,7 +1220,7 @@ async def test_raises_error_when_message_to_output_size_greater_than_256( with pytest.raises(ValueError) as e_info: await client.send_message_to_output(message, output_name) assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_output_event.call_count == 0 + assert mqtt_pipeline.send_output_message.call_count == 0 @pytest.mark.it("Does not raises error when message data size is equal to 256 KB") async def test_raises_error_when_message_to_output_data_equal_to_256( @@ -2067,8 +1236,8 @@ async def test_raises_error_when_message_to_output_data_equal_to_256( await client.send_message_to_output(message, output_name) - assert mqtt_pipeline.send_output_event.call_count == 1 - sent_message = mqtt_pipeline.send_output_event.call_args[0][0] + assert mqtt_pipeline.send_output_message.call_count == 1 + sent_message = mqtt_pipeline.send_output_message.call_args[0][0] assert isinstance(sent_message, Message) assert sent_message.data == data_input @@ -2230,6 +1399,6 @@ def fail_invoke_method(method_params, device_id, callback, module_id=None): @pytest.mark.describe("IoTHubModule (Asynchronous) - PROPERTY .connected") class TestIoTHubModuleClientPROPERTYConnected( - IoTHubModuleClientTestsConfig, SharedClientPROPERTYConnectedTests + IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYConnectedTests ): pass diff --git a/azure-iot-device/tests/iothub/auth/conftest.py b/azure-iot-device/tests/iothub/auth/conftest.py deleted file mode 100644 index 0dcf5af94..000000000 --- a/azure-iot-device/tests/iothub/auth/conftest.py +++ /dev/null @@ -1,7 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -from .shared_auth_fixtures import * diff --git a/azure-iot-device/tests/iothub/auth/shared_auth_fixtures.py b/azure-iot-device/tests/iothub/auth/shared_auth_fixtures.py deleted file mode 100644 index fdc5b21b2..000000000 --- a/azure-iot-device/tests/iothub/auth/shared_auth_fixtures.py +++ /dev/null @@ -1,22 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest - - -@pytest.fixture -def hostname(): - return "__FAKE_HOSTNAME__" - - -@pytest.fixture -def device_id(): - return "__FAKE_DEVICE_ID__" - - -@pytest.fixture -def module_id(): - return "__FAKE_MODULE__ID__" diff --git a/azure-iot-device/tests/iothub/auth/shared_auth_tests.py b/azure-iot-device/tests/iothub/auth/shared_auth_tests.py deleted file mode 100644 index c90074b3a..000000000 --- a/azure-iot-device/tests/iothub/auth/shared_auth_tests.py +++ /dev/null @@ -1,28 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest - - -class SharedAuthenticationProviderInstantiationTests(object): - @pytest.mark.it("Sets the hostname parameter as an instance attribute") - def test_hostname(self, auth_provider, hostname): - assert auth_provider.hostname == hostname - - @pytest.mark.it("Sets the device_id parameter as an instance attribute") - def test_device_id(self, auth_provider, device_id): - assert auth_provider.device_id == device_id - - @pytest.mark.it("Sets the module_id parameter as an instance attribute") - def test_module_id(self, auth_provider, module_id): - assert auth_provider.module_id == module_id - - -class SharedBaseRenewableAuthenticationProviderInstantiationTests( - SharedAuthenticationProviderInstantiationTests -): - # TODO: Complete this testclass after refactoring the class under test - pass diff --git a/azure-iot-device/tests/iothub/auth/test_base_renewable_token_authentication_provider.py b/azure-iot-device/tests/iothub/auth/test_base_renewable_token_authentication_provider.py deleted file mode 100644 index a3e40a453..000000000 --- a/azure-iot-device/tests/iothub/auth/test_base_renewable_token_authentication_provider.py +++ /dev/null @@ -1,180 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- -import pytest -import logging -from mock import MagicMock, patch -from threading import Timer -from azure.iot.device.iothub.auth.base_renewable_token_authentication_provider import ( - BaseRenewableTokenAuthenticationProvider, - DEFAULT_TOKEN_VALIDITY_PERIOD, - DEFAULT_TOKEN_RENEWAL_MARGIN, -) - -logging.basicConfig(level=logging.DEBUG) - - -fake_signature = "__FAKE_SIGNATURE__" -fake_hostname = "__FAKE_HOSTNAME__" -fake_device_id = "__FAKE_DEVICE_ID__" -fake_module_id = "__FAKE_MODULE_ID__" -fake_current_time = 123456 -fake_device_resource_uri = "{}%2Fdevices%2F{}".format(fake_hostname, fake_device_id) -fake_module_resource_uri = "{}%2Fdevices%2F{}%2Fmodules%2F{}".format( - fake_hostname, fake_device_id, fake_module_id -) -fake_device_token_base = "SharedAccessSignature sr={}&sig={}&se=".format( - fake_device_resource_uri, fake_signature -) -fake_module_token_base = "SharedAccessSignature sr={}&sig={}&se=".format( - fake_module_resource_uri, fake_signature -) -new_token_validity_period = 8675 -new_token_renewal_margin = 309 - - -class FakeAuthProvider(BaseRenewableTokenAuthenticationProvider): - def __init__(self, hostname, device_id, module_id): - BaseRenewableTokenAuthenticationProvider.__init__(self, hostname, device_id, module_id) - self._sign = MagicMock(return_value=fake_signature) - - def _sign(self, quoted_resource_uri, expiry): - pass - - def parse(source): - pass - - -@pytest.fixture(scope="function") -def device_auth_provider(): - return FakeAuthProvider(fake_hostname, fake_device_id, None) - - -@pytest.fixture(scope="function") -def module_auth_provider(): - return FakeAuthProvider(fake_hostname, fake_device_id, fake_module_id) - - -@pytest.fixture(scope="function") -def fake_get_current_time_function(): - with patch( - "azure.iot.device.iothub.auth.base_renewable_token_authentication_provider.time.time", - MagicMock(return_value=fake_current_time), - ): - yield - - -@pytest.fixture(scope="function") -def fake_timer_object(): - with patch( - "azure.iot.device.iothub.auth.base_renewable_token_authentication_provider.Timer", - MagicMock(spec=Timer), - ) as PatchedTimer: - yield PatchedTimer - - -def test_device_get_current_sas_token_generates_and_returns_new_sas_token( - device_auth_provider, fake_get_current_time_function -): - token = device_auth_provider.get_current_sas_token() - assert device_auth_provider._sign.call_count == 1 - assert token == fake_device_token_base + str(fake_current_time + DEFAULT_TOKEN_VALIDITY_PERIOD) - - -def test_module_get_current_sas_token_generates_and_returns_new_sas_token( - module_auth_provider, fake_get_current_time_function -): - token = module_auth_provider.get_current_sas_token() - assert module_auth_provider._sign.call_count == 1 - assert token == fake_module_token_base + str(fake_current_time + DEFAULT_TOKEN_VALIDITY_PERIOD) - - -def test_get_current_sas_token_returns_existing_sas_token(device_auth_provider): - token1 = device_auth_provider.get_current_sas_token() - token2 = device_auth_provider.get_current_sas_token() - assert device_auth_provider._sign.call_count == 1 - assert token1 == token2 - - -def test_generate_new_sas_token_calls_on_sas_token_updated_handler_when_sas_updates( - device_auth_provider -): - update_callback_list = [MagicMock(), MagicMock(), MagicMock()] - device_auth_provider.on_sas_token_updated_handler_list = update_callback_list - device_auth_provider.generate_new_sas_token() - for x in update_callback_list: - x.assert_called_once_with() 
- - -def test_device_generate_new_sas_token_calls_sign_with_correct_default_args( - device_auth_provider, fake_get_current_time_function -): - device_auth_provider.generate_new_sas_token() - resource_uri = device_auth_provider._sign.call_args[0][0] - expiry = device_auth_provider._sign.call_args[0][1] - assert resource_uri == fake_device_resource_uri - assert expiry == fake_current_time + DEFAULT_TOKEN_VALIDITY_PERIOD - - -def test_module_generate_new_sas_token_calls_sign_with_correct_default_args( - module_auth_provider, fake_get_current_time_function -): - module_auth_provider.generate_new_sas_token() - resource_uri = module_auth_provider._sign.call_args[0][0] - expiry = module_auth_provider._sign.call_args[0][1] - assert resource_uri == fake_module_resource_uri - assert expiry == fake_current_time + DEFAULT_TOKEN_VALIDITY_PERIOD - - -def test_generate_new_sas_token_calls_sign_with_correct_modified_expiry( - device_auth_provider, fake_get_current_time_function -): - device_auth_provider.token_validity_period = new_token_validity_period - device_auth_provider.token_renewal_margin = new_token_renewal_margin - device_auth_provider.generate_new_sas_token() - expiry = device_auth_provider._sign.call_args[0][1] - assert expiry == fake_current_time + new_token_validity_period - - -def test_generate_new_sas_token_schedules_update_timer_with_correct_default_timeout( - device_auth_provider, fake_timer_object -): - device_auth_provider.generate_new_sas_token() - assert ( - fake_timer_object.call_args[0][0] - == DEFAULT_TOKEN_VALIDITY_PERIOD - DEFAULT_TOKEN_RENEWAL_MARGIN - ) - - -def test_generate_new_sas_token_cancels_and_reschedules_update_timer_with_correct_modified_timeout( - device_auth_provider, fake_timer_object -): - device_auth_provider.token_validity_period = new_token_validity_period - device_auth_provider.token_renewal_margin = new_token_renewal_margin - device_auth_provider.generate_new_sas_token() - assert fake_timer_object.call_args[0][0] == new_token_validity_period - new_token_renewal_margin - - -def test_update_timer_generates_new_sas_token_and_calls_on_sas_token_updated_handler( - device_auth_provider, fake_timer_object -): - update_callback_list = [MagicMock(), MagicMock(), MagicMock()] - device_auth_provider.generate_new_sas_token() - device_auth_provider.on_sas_token_updated_handler_list = update_callback_list - timer_callback = fake_timer_object.call_args[0][1] - device_auth_provider._sign.reset_mock() - timer_callback() - for x in update_callback_list: - x.assert_called_once_with() - assert device_auth_provider._sign.call_count == 1 - - -def test_finalizer_cancels_update_timer(fake_timer_object): - # can't use the device_auth_provider fixture here because the fixture adds - # to the object refcount and prevents del from calling the finalizer - device_auth_provider = FakeAuthProvider(fake_hostname, fake_device_id, None) - device_auth_provider.generate_new_sas_token() - del device_auth_provider - fake_timer_object.return_value.cancel.assert_called_once_with() diff --git a/azure-iot-device/tests/iothub/auth/test_iotedge_authentication_provider.py b/azure-iot-device/tests/iothub/auth/test_iotedge_authentication_provider.py deleted file mode 100644 index 093243997..000000000 --- a/azure-iot-device/tests/iothub/auth/test_iotedge_authentication_provider.py +++ /dev/null @@ -1,341 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import requests -import json -import base64 -import logging -import six.moves.urllib as urllib -from azure.iot.device.iothub.auth.iotedge_authentication_provider import ( - IoTEdgeAuthenticationProvider, - IoTEdgeHsm, - IoTEdgeError, -) -from .shared_auth_tests import SharedBaseRenewableAuthenticationProviderInstantiationTests -from azure.iot.device.product_info import ProductInfo - -logging.basicConfig(level=logging.DEBUG) - - -@pytest.fixture -def gateway_hostname(): - return "__FAKE_GATEWAY_HOSTNAME__" - - -@pytest.fixture -def module_generation_id(): - return "__FAKE_MODULE_GENERATION_ID__" - - -@pytest.fixture -def workload_uri(): - return "http://__FAKE_WORKLOAD_URI__/" - - -@pytest.fixture -def api_version(): - return "__FAKE_API_VERSION__" - - -@pytest.fixture -def certificate(): - return "__FAKE_CERTIFICATE__" - - -@pytest.fixture -def mock_hsm(mocker, certificate): - mock_hsm = mocker.patch( - "azure.iot.device.iothub.auth.iotedge_authentication_provider.IoTEdgeHsm" - ).return_value - mock_hsm.get_trust_bundle.return_value = certificate - return mock_hsm - - -@pytest.fixture -def hsm(module_id, module_generation_id, workload_uri, api_version): - return IoTEdgeHsm( - module_id=module_id, - module_generation_id=module_generation_id, - workload_uri=workload_uri, - api_version=api_version, - ) - - -@pytest.fixture -def auth_provider( - mock_hsm, - hostname, - device_id, - module_id, - gateway_hostname, - module_generation_id, - workload_uri, - api_version, -): - return IoTEdgeAuthenticationProvider( - hostname=hostname, - device_id=device_id, - module_id=module_id, - gateway_hostname=gateway_hostname, - module_generation_id=module_generation_id, - workload_uri=workload_uri, - api_version=api_version, - ) - - -####################################### -# IoTEdgeAuthenticationProvider Tests # -####################################### - - -@pytest.mark.describe("IoTEdgeAuthenticationProvider - Instantiation") -class TestIoTEdgeAuthenticationProviderInstantiation( - SharedBaseRenewableAuthenticationProviderInstantiationTests -): - - # TODO: Increase coverage by completing parent class - - @pytest.mark.it("Sets the gateway_hostname parameter as an instance attribute") - def test_gateway_hostname(self, auth_provider, gateway_hostname): - assert auth_provider.gateway_hostname == gateway_hostname - - @pytest.mark.it("Creates an instance of the IoTEdgeHsm") - def test_creates_edge_hsm(self, auth_provider, mock_hsm): - assert auth_provider.hsm is mock_hsm - - @pytest.mark.it( - "Sets a certificate acquired from the IoTEdgeHsm as the server_verification_cert instance attribute" - ) - def test_server_verification_cert_from_edge_hsm(self, auth_provider, mock_hsm): - assert auth_provider.server_verification_cert is mock_hsm.get_trust_bundle.return_value - assert mock_hsm.get_trust_bundle.call_count == 1 - - -# TODO: Potentially get rid of this test class depending on how the parent class is tested/refactored. -# After all, we really shouldn't be testing convention-private methods. 
-@pytest.mark.describe("IoTEdgeAuthenticationProvider - ._sign()") -class TestIoTEdgeAuthenticationProviderSign(object): - @pytest.mark.it("Requests signing of a string in the format '/n'") - def test_sign_request(self, mocker, auth_provider, mock_hsm): - uri = "my/resource/uri" - expiry = 1234567 - string_to_sign = uri + "\n" + str(expiry) - - auth_provider._sign(uri, expiry) - - assert mock_hsm.sign.call_count == 1 - assert mock_hsm.sign.call_args == mocker.call(string_to_sign) - - @pytest.mark.it("Returns the signed string provided by the IoTEdgeHsm") - def test_returns_signed_response(self, auth_provider, mock_hsm): - uri = "my/resource/uri" - expiry = 1234567 - - signed_string = auth_provider._sign(uri, expiry) - - assert signed_string is mock_hsm.sign.return_value - - -#################### -# IoTEdgeHsm Tests # -#################### - - -@pytest.mark.describe("IoTEdgeHsm - Instantiation") -class TestIoTEdgeHsmInstantiation(object): - @pytest.mark.it("URL encodes the module_id parameter and sets it as an instance attribute") - def test_module_id(self, module_generation_id, workload_uri, api_version): - my_module_id = "not url //encoded" - expected_module_id = urllib.parse.quote(my_module_id) - hsm = IoTEdgeHsm( - module_id=my_module_id, - module_generation_id=module_generation_id, - workload_uri=workload_uri, - api_version=api_version, - ) - - assert my_module_id != expected_module_id - assert hsm.module_id == expected_module_id - - @pytest.mark.it("Sets the module_generation_id paramater as an instance attribute") - def test_module_generation_id(self, hsm, module_generation_id): - assert hsm.module_generation_id == module_generation_id - - @pytest.mark.it( - "Converts the workload_uri parameter into requests-unixsocket format and sets it as an instance attribute" - ) - def test_workload_uri(self, module_id, module_generation_id, api_version): - my_workload_uri = "unix:///var/run/iotedge/workload.sock" - expected_workload_uri = "http+unix://%2Fvar%2Frun%2Fiotedge%2Fworkload.sock/" - hsm = IoTEdgeHsm( - module_id=module_id, - module_generation_id=module_generation_id, - workload_uri=my_workload_uri, - api_version=api_version, - ) - - assert hsm.workload_uri == expected_workload_uri - - @pytest.mark.it("Sets the api_version paramater as an instance attribute") - def test_api_version(self, hsm, api_version): - assert hsm.api_version == api_version - - -@pytest.mark.describe("IoTEdgeHsm - .get_trust_bundle()") -class TestIoTEdgeHsmGetTrustBundle(object): - @pytest.mark.it("Makes an HTTP request to EdgeHub for the trust bundle") - def test_requests_trust_bundle(self, mocker, hsm): - mock_request_get = mocker.patch.object(requests, "get") - expected_url = hsm.workload_uri + "trust-bundle" - expected_params = {"api-version": hsm.api_version} - expected_headers = { - "User-Agent": urllib.parse.quote_plus(ProductInfo.get_iothub_user_agent()) - } - - hsm.get_trust_bundle() - - assert mock_request_get.call_count == 1 - assert mock_request_get.call_args == mocker.call( - expected_url, params=expected_params, headers=expected_headers - ) - - @pytest.mark.it("Returns the certificate from the trust bundle received from EdgeHub") - def test_returns_received_trust_bundle(self, mocker, hsm, certificate): - mock_request_get = mocker.patch.object(requests, "get") - mock_response = mock_request_get.return_value - mock_response.json.return_value = {"certificate": certificate} - - cert = hsm.get_trust_bundle() - - assert cert is certificate - - @pytest.mark.it("Raises IoTEdgeError if a bad request is made 
to EdgeHub") - def test_bad_request(self, mocker, hsm): - mock_request_get = mocker.patch.object(requests, "get") - mock_response = mock_request_get.return_value - error = requests.exceptions.HTTPError() - mock_response.raise_for_status.side_effect = error - - with pytest.raises(IoTEdgeError) as e_info: - hsm.get_trust_bundle() - assert e_info.value.__cause__ is error - - @pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the trust bundle") - def test_bad_json(self, mocker, hsm): - mock_request_get = mocker.patch.object(requests, "get") - mock_response = mock_request_get.return_value - error = ValueError() - mock_response.json.side_effect = error - - with pytest.raises(IoTEdgeError) as e_info: - hsm.get_trust_bundle() - assert e_info.value.__cause__ is error - - @pytest.mark.it("Raises IoTEdgeError if the certificate is missing from the trust bundle") - def test_bad_trust_bundle(self, mocker, hsm): - mock_request_get = mocker.patch.object(requests, "get") - mock_response = mock_request_get.return_value - # Return an empty json dict with no 'certificate' key - mock_response.json.return_value = {} - - with pytest.raises(IoTEdgeError): - hsm.get_trust_bundle() - - -@pytest.mark.describe("IoTEdgeHsm - .sign()") -class TestIoTEdgeHsmSign(object): - @pytest.mark.it("Makes an HTTP request to EdgeHub to sign a piece of string data") - def test_requests_data_signing(self, mocker, hsm): - data_str = "somedata" - data_str_b64 = "c29tZWRhdGE=" - mock_request_post = mocker.patch.object(requests, "post") - mock_request_post.return_value.json.return_value = {"digest": "somedigest"} - expected_url = "{workload_uri}modules/{module_id}/genid/{module_generation_id}/sign".format( - workload_uri=hsm.workload_uri, - module_id=hsm.module_id, - module_generation_id=hsm.module_generation_id, - ) - expected_params = {"api-version": hsm.api_version} - expected_headers = { - "User-Agent": urllib.parse.quote_plus(ProductInfo.get_iothub_user_agent()) - } - expected_json = json.dumps({"keyId": "primary", "algo": "HMACSHA256", "data": data_str_b64}) - - hsm.sign(data_str) - - assert mock_request_post.call_count == 1 - assert mock_request_post.call_args == mocker.call( - url=expected_url, params=expected_params, headers=expected_headers, data=expected_json - ) - - @pytest.mark.it("Base64 encodes the string data in the request") - def test_b64_encodes_data(self, mocker, hsm): - # This test is actually implicitly tested in the first test, but it's - # important to have an explicit test for it since it's a requirement - data_str = "somedata" - data_str_b64 = base64.b64encode(data_str.encode("utf-8")).decode() - mock_request_post = mocker.patch.object(requests, "post") - mock_request_post.return_value.json.return_value = {"digest": "somedigest"} - - hsm.sign(data_str) - - sent_data = json.loads(mock_request_post.call_args[1]["data"])["data"] - - assert data_str != data_str_b64 - assert sent_data == data_str_b64 - - @pytest.mark.it("Returns the signed data received from EdgeHub") - def test_returns_signed_data(self, mocker, hsm): - expected_digest = "somedigest" - mock_request_post = mocker.patch.object(requests, "post") - mock_request_post.return_value.json.return_value = {"digest": expected_digest} - - signed_data = hsm.sign("somedata") - - assert signed_data == expected_digest - - @pytest.mark.it("URL encodes the signed data before returning it") - def test_url_encodes_signed_data(self, mocker, hsm): - raw_signed_data = "this digest will be encoded" - expected_signed_data = 
urllib.parse.quote(raw_signed_data) - mock_request_post = mocker.patch.object(requests, "post") - mock_request_post.return_value.json.return_value = {"digest": raw_signed_data} - - signed_data = hsm.sign("somedata") - - assert raw_signed_data != expected_signed_data - assert signed_data == expected_signed_data - - @pytest.mark.it("Raises IoTEdgeError if a bad request is made to EdgeHub") - def test_bad_request(self, mocker, hsm): - mock_request_post = mocker.patch.object(requests, "post") - mock_response = mock_request_post.return_value - error = requests.exceptions.HTTPError() - mock_response.raise_for_status.side_effect = error - - with pytest.raises(IoTEdgeError) as e_info: - hsm.sign("somedata") - assert e_info.value.__cause__ is error - - @pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the signed response") - def test_bad_json(self, mocker, hsm): - mock_request_post = mocker.patch.object(requests, "post") - mock_response = mock_request_post.return_value - error = ValueError() - mock_response.json.side_effect = error - with pytest.raises(IoTEdgeError) as e_info: - hsm.sign("somedata") - assert e_info.value.__cause__ is error - - @pytest.mark.it("Raises IoTEdgeError if the signed data is missing from the response") - def test_bad_response(self, mocker, hsm): - mock_request_post = mocker.patch.object(requests, "post") - mock_response = mock_request_post.return_value - mock_response.json.return_value = {} - - with pytest.raises(IoTEdgeError): - hsm.sign("somedata") diff --git a/azure-iot-device/tests/iothub/auth/test_sas_authentication_provider.py b/azure-iot-device/tests/iothub/auth/test_sas_authentication_provider.py deleted file mode 100644 index 9364165f7..000000000 --- a/azure-iot-device/tests/iothub/auth/test_sas_authentication_provider.py +++ /dev/null @@ -1,195 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.iothub.auth.sas_authentication_provider import ( - SharedAccessSignatureAuthenticationProvider, -) - -logging.basicConfig(level=logging.DEBUG) - - -sas_device_token_format = "SharedAccessSignature sr={}&sig={}&se={}" -sas_device_skn_token_format = "SharedAccessSignature sr={}&sig={}&se={}&skn={}" - - -shared_access_key_name = "alohomora" -hostname = "beauxbatons.academy-net" -device_id = "MyPensieve" -module_id = "Divination" - -signature = "IsolemnlySwearThatIamuUptoNogood" -expiry = "1539043658" - - -def create_sas_token_string_device(is_module=False, is_key_name=False): - uri = hostname + "/devices/" + device_id - if is_module: - uri = uri + "/modules/" + module_id - if is_key_name: - return sas_device_skn_token_format.format(uri, signature, expiry, shared_access_key_name) - else: - return sas_device_token_format.format(uri, signature, expiry) - - -def test_sas_auth_provider_is_created_from_device_sas_token_string(): - sas_string = create_sas_token_string_device() - sas_auth_provider = SharedAccessSignatureAuthenticationProvider.parse(sas_string) - assert sas_auth_provider.hostname == hostname - assert sas_auth_provider.device_id == device_id - assert hostname in sas_auth_provider.sas_token_str - assert device_id in sas_auth_provider.sas_token_str - - -def test_sas_auth_provider_is_created_from_module_sas_token_string(): - sas_string = create_sas_token_string_device(True) - sas_auth_provider = SharedAccessSignatureAuthenticationProvider.parse(sas_string) - assert sas_auth_provider.hostname == hostname - assert sas_auth_provider.device_id == device_id - assert hostname in sas_auth_provider.sas_token_str - assert device_id in sas_auth_provider.sas_token_str - assert sas_auth_provider.module_id == module_id - assert hostname in sas_auth_provider.sas_token_str - assert device_id in sas_auth_provider.sas_token_str - assert module_id in sas_auth_provider.sas_token_str - - -def test_sas_auth_provider_is_created_from_device_sas_token_string_with_keyname(): - sas_string = create_sas_token_string_device(False, True) - sas_auth_provider = SharedAccessSignatureAuthenticationProvider.parse(sas_string) - assert sas_auth_provider.hostname == hostname - assert sas_auth_provider.device_id == device_id - assert hostname in sas_auth_provider.sas_token_str - assert device_id in sas_auth_provider.sas_token_str - assert shared_access_key_name in sas_auth_provider.sas_token_str - - -def test_sas_auth_provider_is_created_from_device_sas_token_string_quoted(): - sas_string_quoted = "SharedAccessSignature sr=beauxbatons.academy-net%2Fdevices%2FMyPensieve&sig=IsolemnlySwearThatIamuUptoNogood&se=1539043658&skn=alohomora" - sas_auth_provider = SharedAccessSignatureAuthenticationProvider.parse(sas_string_quoted) - assert sas_auth_provider.hostname == hostname - assert sas_auth_provider.device_id == device_id - assert hostname in sas_auth_provider.sas_token_str - assert device_id in sas_auth_provider.sas_token_str - - -def test_raises_when_auth_provider_created_from_empty_shared_access_signature_string(): - with pytest.raises( - ValueError, - match="The Shared Access Signature is required and should not be empty or blank and must be supplied as a string consisting of two parts in the format 'SharedAccessSignature sr=&sig=&se=' with an optional skn=", - ): - SharedAccessSignatureAuthenticationProvider.parse("") - - -def 
test_raises_when_auth_provider_created_from_none_shared_access_signature_string(): - with pytest.raises( - ValueError, - match="The Shared Access Signature is required and should not be empty or blank and must be supplied as a string consisting of two parts in the format 'SharedAccessSignature sr=&sig=&se=' with an optional skn=", - ): - SharedAccessSignatureAuthenticationProvider.parse(None) - - -def test_raises_when_auth_provider_created_from_blank_shared_access_signature_string(): - with pytest.raises( - ValueError, - match="The Shared Access Signature is required and should not be empty or blank and must be supplied as a string consisting of two parts in the format 'SharedAccessSignature sr=&sig=&se=' with an optional skn=", - ): - SharedAccessSignatureAuthenticationProvider.parse(" ") - - -def test_raises_when_auth_provider_created_from_numeric_shared_access_signature_string(): - with pytest.raises( - ValueError, - match="The Shared Access Signature is required and should not be empty or blank and must be supplied as a string consisting of two parts in the format 'SharedAccessSignature sr=&sig=&se=' with an optional skn=", - ): - SharedAccessSignatureAuthenticationProvider.parse(873915) - - -def test_raises_when_auth_provider_created_from_object_shared_access_signature_string(): - with pytest.raises( - ValueError, - match="The Shared Access Signature is required and should not be empty or blank and must be supplied as a string consisting of two parts in the format 'SharedAccessSignature sr=&sig=&se=' with an optional skn=", - ): - SharedAccessSignatureAuthenticationProvider.parse(object) - - -def test_raises_when_auth_provider_created_from_shared_access_signature_string_blank_second_part(): - with pytest.raises( - ValueError, - match="The Shared Access Signature is required and should not be empty or blank and must be supplied as a string consisting of two parts in the format 'SharedAccessSignature sr=&sig=&se=' with an optional skn=", - ): - SharedAccessSignatureAuthenticationProvider.parse("SharedAccessSignature ") - - -def test_raises_when_auth_provider_created_from_shared_access_signature_string_numeric_second_part(): - with pytest.raises( - ValueError, - match="The Shared Access Signature is required and should not be empty or blank and must be supplied as a string consisting of two parts in the format 'SharedAccessSignature sr=&sig=&se=' with an optional skn=", - ): - SharedAccessSignatureAuthenticationProvider.parse("SharedAccessSignature 67998311999") - - -def test_raises_when_auth_provider_created_from_shared_access_signature_string_numeric_value_second_part(): - with pytest.raises( - ValueError, - match="One of the name value pair of the Shared Access Signature string should be a proper resource uri", - ): - SharedAccessSignatureAuthenticationProvider.parse( - "SharedAccessSignature sr=67998311999&sig=24234234&se=1539043658&skn=25245245" - ) - - -def test_raises_when_auth_provider_created_from_shared_access_signature_string_with_incomplete_sr(): - with pytest.raises( - ValueError, - match="One of the name value pair of the Shared Access Signature string should be a proper resource uri", - ): - SharedAccessSignatureAuthenticationProvider.parse( - "SharedAccessSignature sr=MyPensieve&sig=IsolemnlySwearThatIamuUptoNogood&se=1539043658&skn=alohomora" - ) - - -def test_raises_auth_provider_created_from_missing_part_shared_access_signature_string(): - with pytest.raises( - ValueError, - match="The Shared Access Signature is required and should not be empty or blank and must be 
supplied as a string consisting of two parts in the format 'SharedAccessSignature sr=&sig=&se=' with an optional skn=", - ): - one_part_sas_str = "sr=beauxbatons.academy-net%2Fdevices%2FMyPensieve&sig=IsolemnlySwearThatIamuUptoNogood&se=1539043658&skn=alohomora" - SharedAccessSignatureAuthenticationProvider.parse(one_part_sas_str) - - -def test_raises_auth_provider_created_from_more_parts_shared_access_signature_string(): - with pytest.raises( - ValueError, - match="The Shared Access Signature must be of the format 'SharedAccessSignature sr=&sig=&se=' or/and it can additionally contain an optional skn= name=value pair.", - ): - more_part_sas_str = "SharedAccessSignature sr=beauxbatons.academy-net%2Fdevices%2FMyPensieve&sig=IsolemnlySwearThatIamuUptoNogood&se=1539043658&skn=alohomora SharedAccessSignature" - SharedAccessSignatureAuthenticationProvider.parse(more_part_sas_str) - - -def test_raises_auth_provider_created_from_shared_access_signature_string_duplicate_keys(): - with pytest.raises(ValueError, match="Invalid Shared Access Signature - Unable to parse"): - duplicate_sas_str = "SharedAccessSignature sr=beauxbatons.academy-net%2Fdevices%2FMyPensieve&sig=IsolemnlySwearThatIamuUptoNogood&se=1539043658&sr=alohomora" - SharedAccessSignatureAuthenticationProvider.parse(duplicate_sas_str) - - -def test_raises_auth_provider_created_from_shared_access_signature_string_bad_keys(): - with pytest.raises( - ValueError, - match="Invalid keys in Shared Access Signature. The valid keys are sr, sig, se and an optional skn.", - ): - bad_key_sas_str = "SharedAccessSignature sr=beauxbatons.academy-net%2Fdevices%2FMyPensieve&signature=IsolemnlySwearThatIamuUptoNogood&se=1539043658&skn=alohomora" - SharedAccessSignatureAuthenticationProvider.parse(bad_key_sas_str) - - -def test_raises_auth_provider_created_from_incomplete_shared_access_signature_string(): - with pytest.raises( - ValueError, - match="Invalid Shared Access Signature. It must be of the format 'SharedAccessSignature sr=&sig=&se=' or/and it can additionally contain an optional skn= name=value pair.", - ): - incomplete_sas_str = "SharedAccessSignature sr=beauxbatons.academy-net%2Fdevices%2FMyPensieve&se=1539043658&skn=alohomora" - SharedAccessSignatureAuthenticationProvider.parse(incomplete_sas_str) diff --git a/azure-iot-device/tests/iothub/auth/test_sk_authentication_provider.py b/azure-iot-device/tests/iothub/auth/test_sk_authentication_provider.py deleted file mode 100644 index 2f9f31dce..000000000 --- a/azure-iot-device/tests/iothub/auth/test_sk_authentication_provider.py +++ /dev/null @@ -1,143 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.iothub.auth.sk_authentication_provider import ( - SymmetricKeyAuthenticationProvider, -) - -from mock import MagicMock - -logging.basicConfig(level=logging.DEBUG) - - -connection_string_device_sk_format = "HostName={};DeviceId={};SharedAccessKey={}" -connection_string_device_skn_format = ( - "HostName={};DeviceId={};SharedAccessKeyName={};SharedAccessKey={}" -) -connection_string_module_sk_format = "HostName={};DeviceId={};ModuleId={};SharedAccessKey={}" - -shared_access_key = "Zm9vYmFy" -shared_access_key_name = "alohomora" -hostname = "beauxbatons.academy-net" -device_id = "MyPensieve" -module_id = "Divination" - - -def test_all_attributes_for_device(): - connection_string = connection_string_device_sk_format.format( - hostname, device_id, shared_access_key - ) - sym_key_auth_provider = SymmetricKeyAuthenticationProvider.parse(connection_string) - - assert sym_key_auth_provider.device_id == device_id - assert hostname in sym_key_auth_provider.get_current_sas_token() - assert device_id in sym_key_auth_provider.get_current_sas_token() - - -def test_all_attributes_for_module(): - connection_string = connection_string_module_sk_format.format( - hostname, device_id, module_id, shared_access_key - ) - sym_key_auth_provider = SymmetricKeyAuthenticationProvider.parse(connection_string) - - assert sym_key_auth_provider.hostname == hostname - assert sym_key_auth_provider.device_id == device_id - assert sym_key_auth_provider.module_id == module_id - assert hostname in sym_key_auth_provider.get_current_sas_token() - assert device_id in sym_key_auth_provider.get_current_sas_token() - assert module_id in sym_key_auth_provider.get_current_sas_token() - - -def test_sastoken_keyname_device(): - connection_string = connection_string_device_skn_format.format( - hostname, device_id, shared_access_key_name, shared_access_key - ) - - sym_key_auth_provider = SymmetricKeyAuthenticationProvider.parse(connection_string) - - assert hostname in sym_key_auth_provider.get_current_sas_token() - assert device_id in sym_key_auth_provider.get_current_sas_token() - assert shared_access_key_name in sym_key_auth_provider.get_current_sas_token() - - -def test_raises_when_auth_provider_created_from_empty_connection_string(): - with pytest.raises( - ValueError, - match="Connection string is required and should not be empty or blank and must be supplied as a string", - ): - SymmetricKeyAuthenticationProvider.parse("") - - -def test_raises_when_auth_provider_created_from_none_connection_string(): - with pytest.raises( - ValueError, - match="Connection string is required and should not be empty or blank and must be supplied as a string", - ): - SymmetricKeyAuthenticationProvider.parse(None) - - -def test_raises_when_auth_provider_created_from_blank_connection_string(): - with pytest.raises( - ValueError, - match="Connection string is required and should not be empty or blank and must be supplied as a string", - ): - SymmetricKeyAuthenticationProvider.parse(" ") - - -def test_raises_when_auth_provider_created_from_numeric_connection_string(): - with pytest.raises( - ValueError, - match="Connection string is required and should not be empty or blank and must be supplied as a string", - ): - SymmetricKeyAuthenticationProvider.parse(654354) - - -def test_raises_when_auth_provider_created_from_connection_string_object(): - with pytest.raises( - ValueError, - match="Connection string is required and 
should not be empty or blank and must be supplied as a string", - ): - SymmetricKeyAuthenticationProvider.parse(object) - - -def test_raises_when_auth_provider_created_connection_string_with_numeric_argument(): - with pytest.raises( - ValueError, - match="Connection string is required and should not be empty or blank and must be supplied as a string", - ): - connection_string = "HostName^43443434" - SymmetricKeyAuthenticationProvider.parse(connection_string) - - -def test_raises_when_auth_provider_created_from_incomplete_connection_string(): - with pytest.raises(ValueError, match="Invalid Connection String - Incomplete"): - connection_string = "HostName=beauxbatons.academy-net;SharedAccessKey=Zm9vYmFy" - SymmetricKeyAuthenticationProvider.parse(connection_string) - - -def test_raises_when_auth_provider_created_from_connection_string_with_duplicatekeys(): - with pytest.raises(ValueError, match="Invalid Connection String - Unable to parse"): - connection_string = ( - "HostName=beauxbatons.academy-net;HostName=TheDeluminator;HostName=Zm9vYmFy" - ) - SymmetricKeyAuthenticationProvider.parse(connection_string) - - -def test_raises_when_auth_provider_created_from_connection_string_without_proper_delimeter(): - with pytest.raises( - ValueError, - match="Connection string is required and should not be empty or blank and must be supplied as a string", - ): - connection_string = "HostName+beauxbatons.academy-net!DeviceId+TheDeluminator!" - SymmetricKeyAuthenticationProvider.parse(connection_string) - - -def test_raises_when_auth_provider_created_from_connection_string_with_bad_keys(): - with pytest.raises(ValueError, match="Invalid Connection String - Invalid Key"): - connection_string = "BadHostName=beauxbatons.academy-net;BadDeviceId=TheDeluminator;SharedAccessKey=Zm9vYmFy" - SymmetricKeyAuthenticationProvider.parse(connection_string) diff --git a/azure-iot-device/tests/iothub/auth/test_x509_authentication_provider.py b/azure-iot-device/tests/iothub/auth/test_x509_authentication_provider.py deleted file mode 100644 index 529d9563a..000000000 --- a/azure-iot-device/tests/iothub/auth/test_x509_authentication_provider.py +++ /dev/null @@ -1,67 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.iothub.auth.x509_authentication_provider import X509AuthenticationProvider -from azure.iot.device.common.models.x509 import X509 - - -logging.basicConfig(level=logging.DEBUG) - -hostname = "beauxbatons.academy-net" -device_id = "MyPensieve" -fake_x509_cert_value = "fantastic_beasts" -fake_x509_cert_key = "where_to_find_them" -fake_pass_phrase = "alohomora" -module_id = "Transfiguration" - - -def x509(): - return X509(fake_x509_cert_value, fake_x509_cert_key, fake_pass_phrase) - - -@pytest.mark.describe("X509AuthenticationProvider") -class TestX509AuthenticationProvider(object): - @pytest.mark.it("Instantiates with hostname") - def test_instantiates_correctly_with_hostname(self): - x509_cert_object = x509() - x509_auth_provider = X509AuthenticationProvider( - x509=x509_cert_object, hostname=hostname, device_id=device_id - ) - assert x509_auth_provider.hostname == hostname - - @pytest.mark.it("Instantiates with device_id") - def test_instantiates_correctly_with_device_id(self): - x509_cert_object = x509() - x509_auth_provider = X509AuthenticationProvider( - x509=x509_cert_object, hostname=hostname, device_id=device_id - ) - assert x509_auth_provider.device_id == device_id - - @pytest.mark.it("Instantiates with module_id") - def test_instantiates_correctly_with_module_id(self): - x509_cert_object = x509() - x509_auth_provider = X509AuthenticationProvider( - x509=x509_cert_object, hostname=hostname, device_id=device_id, module_id=module_id - ) - assert x509_auth_provider.module_id == module_id - - @pytest.mark.it("Instantiates with module_id defaulting to None") - def test_instantiates_correctly_with_device_id_and_optional_module_id(self): - x509_cert_object = x509() - x509_auth_provider = X509AuthenticationProvider( - x509=x509_cert_object, hostname=hostname, device_id=device_id - ) - assert x509_auth_provider.module_id is None - - @pytest.mark.it("Getter retrieves the x509 certificate object") - def test_get_certificate(self): - x509_cert_object = x509() - x509_auth_provider = X509AuthenticationProvider( - x509=x509_cert_object, hostname=hostname, device_id=device_id - ) - assert x509_auth_provider.get_x509_certificate() diff --git a/azure-iot-device/tests/iothub/client_fixtures.py b/azure-iot-device/tests/iothub/client_fixtures.py index 99090947b..84c207c23 100644 --- a/azure-iot-device/tests/iothub/client_fixtures.py +++ b/azure-iot-device/tests/iothub/client_fixtures.py @@ -8,12 +8,7 @@ from azure.iot.device.iothub.pipeline import constant from azure.iot.device.iothub.models import Message, MethodResponse, MethodRequest from azure.iot.device.common.models.x509 import X509 -from azure.iot.device.iothub.auth import ( - SymmetricKeyAuthenticationProvider, - SharedAccessSignatureAuthenticationProvider, - IoTEdgeAuthenticationProvider, - X509AuthenticationProvider, -) + """---Constants---""" @@ -60,6 +55,11 @@ def twin_patch_reported(): return {"properties": {"reported": {"bar": 2}}} +@pytest.fixture +def fake_twin(): + return {"fake_twin": True} + + """----Shared connection string fixtures----""" device_connection_string_format = ( @@ -107,7 +107,7 @@ def module_connection_string(request): ) -"""----Shared sas token fixtures---""" +"""----Shared SAS fixtures---""" sas_token_format = "SharedAccessSignature sr={uri}&sig={signature}&se={expiry}" # when to use the skn format? 
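Regarding the "# when to use the skn format?" question above: in the deleted SAS authentication tests, the skn= segment is appended only when a shared access key name (an access-policy name) is part of the credentials; plain device or module tokens omit it. A hedged sketch of the token-string assembly those tests used, with the same illustrative constants:

    sas_token_format = "SharedAccessSignature sr={uri}&sig={signature}&se={expiry}"
    sas_skn_token_format = "SharedAccessSignature sr={uri}&sig={signature}&se={expiry}&skn={keyname}"

    def fake_sas_token_string(hostname, device_id, signature, expiry, module_id=None, key_name=None):
        # Mirrors create_sas_token_string_device() from the deleted test module.
        uri = hostname + "/devices/" + device_id
        if module_id:
            uri = uri + "/modules/" + module_id
        if key_name:
            return sas_skn_token_format.format(uri=uri, signature=signature, expiry=expiry, keyname=key_name)
        return sas_token_format.format(uri=uri, signature=signature, expiry=expiry)

    # e.g. fake_sas_token_string("beauxbatons.academy-net", "MyPensieve", "IsolemnlySwearThatIamuUptoNogood", "1539043658")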
@@ -170,6 +170,16 @@ def edge_local_debug_environment(): } +@pytest.fixture +def mock_edge_hsm(mocker): + mock_edge_hsm = mocker.patch("azure.iot.device.iothub.edge_hsm.IoTEdgeHsm") + mock_edge_hsm.return_value.sign.return_value = "8NJRMT83CcplGrAGaUVIUM/md5914KpWVNngSVoF9/M=" + mock_edge_hsm.return_value.get_certificate.return_value = ( + "__FAKE_SERVER_VERIFICATION_CERTIFICATE__" + ) + return mock_edge_hsm + + """----Shared mock pipeline fixture----""" @@ -192,7 +202,7 @@ def disable_feature(self, feature_name, callback): def send_message(self, event, callback): callback() - def send_output_event(self, event, callback): + def send_output_message(self, event, callback): callback() def send_method_response(self, method_response, callback): @@ -254,23 +264,10 @@ def http_pipeline_manual_cb(mocker): @pytest.fixture -def fake_twin(): - return {"fake_twin": True} - - -"""----Shared symmetric key fixtures----""" - - -@pytest.fixture -def symmetric_key(): - return shared_access_key - - -@pytest.fixture -def hostname_fixture(): - return hostname +def mock_mqtt_pipeline_init(mocker): + return mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline") @pytest.fixture -def device_id_fixture(): - return device_id +def mock_http_pipeline_init(mocker): + return mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") diff --git a/azure-iot-device/tests/iothub/conftest.py b/azure-iot-device/tests/iothub/conftest.py index 770ef8217..c3d9348ee 100644 --- a/azure-iot-device/tests/iothub/conftest.py +++ b/azure-iot-device/tests/iothub/conftest.py @@ -14,10 +14,13 @@ method_request, twin_patch_desired, twin_patch_reported, + fake_twin, mqtt_pipeline, mqtt_pipeline_manual_cb, http_pipeline, http_pipeline_manual_cb, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, device_connection_string, module_connection_string, device_sas_token_string, @@ -25,10 +28,7 @@ edge_container_environment, edge_local_debug_environment, x509, - fake_twin, - symmetric_key, - device_id_fixture, - hostname_fixture, + mock_edge_hsm, ) collect_ignore = [] diff --git a/azure-iot-device/tests/iothub/pipeline/conftest.py b/azure-iot-device/tests/iothub/pipeline/conftest.py index f0fbc1a79..eacfb1488 100644 --- a/azure-iot-device/tests/iothub/pipeline/conftest.py +++ b/azure-iot-device/tests/iothub/pipeline/conftest.py @@ -4,6 +4,7 @@ # license information. # -------------------------------------------------------------------------- +import pytest from tests.common.pipeline.fixtures import ( fake_pipeline_thread, fake_non_pipeline_thread, @@ -11,3 +12,21 @@ arbitrary_op, arbitrary_event, ) + +from azure.iot.device.iothub.pipeline import constant + +# Update this list with features as they are added to the SDK +# NOTE: should this be refactored into a fixture so it doesn't have to be imported? +# Is this used anywhere that DOESN'T just turn it into a fixture? +all_features = [ + constant.C2D_MSG, + constant.INPUT_MSG, + constant.METHODS, + constant.TWIN, + constant.TWIN_PATCHES, +] + + +@pytest.fixture(params=all_features) +def iothub_pipeline_feature(request): + return request.param diff --git a/azure-iot-device/tests/iothub/pipeline/helpers.py b/azure-iot-device/tests/iothub/pipeline/helpers.py deleted file mode 100644 index 6b28dba49..000000000 --- a/azure-iot-device/tests/iothub/pipeline/helpers.py +++ /dev/null @@ -1,20 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from azure.iot.device.iothub.pipeline import pipeline_events_iothub, pipeline_ops_iothub - -all_iothub_ops = [ - pipeline_ops_iothub.SetAuthProviderOperation, - pipeline_ops_iothub.SetIoTHubConnectionArgsOperation, - pipeline_ops_iothub.SendD2CMessageOperation, - pipeline_ops_iothub.SendOutputEventOperation, -] - - -all_iothub_events = [ - pipeline_events_iothub.C2DMessageEvent, - pipeline_events_iothub.InputMessageEvent, - pipeline_events_iothub.MethodRequestEvent, -] diff --git a/azure-iot-device/tests/iothub/pipeline/test_config.py b/azure-iot-device/tests/iothub/pipeline/test_config.py index 5957dd556..b44ac675a 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_config.py +++ b/azure-iot-device/tests/iothub/pipeline/test_config.py @@ -5,39 +5,78 @@ # -------------------------------------------------------------------------- import pytest import logging -from tests.common.pipeline.pipeline_config_test import PipelineConfigInstantiationTestBase +from tests.common.pipeline.config_test import PipelineConfigInstantiationTestBase from azure.iot.device.iothub.pipeline.config import IoTHubPipelineConfig +device_id = "my_device" +module_id = "my_module" +hostname = "hostname.some-domain.net" +product_info = "some_info" + @pytest.mark.describe("IoTHubPipelineConfig - Instantiation") class TestIoTHubPipelineConfigInstantiation(PipelineConfigInstantiationTestBase): + + # This fixture is needed for tests inherited from the parent class @pytest.fixture def config_cls(self): - # This fixture is needed for the parent class return IoTHubPipelineConfig + # This fixture is needed for tests inherited from the parent class + @pytest.fixture + def required_kwargs(self): + return {"device_id": device_id, "hostname": hostname} + + # The parent class defines the auth mechanism fixtures (sastoken, x509). + # For the sake of ease of testing, we will assume sastoken is being used unless + # there is a strict need to do something else. + # It does not matter which is used for the purposes of these tests. 
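The sastoken fixture referenced in the comment above lives in the shared parent test module and is not shown in this diff; a hedged guess at its minimal shape, consistent with the mocked configurations later in this change that set sastoken.ttl "for compat":

    import pytest

    @pytest.fixture
    def sastoken(mocker):
        # Assumed stand-in only: a MagicMock is enough for IoTHubPipelineConfig in these tests,
        # provided it carries a numeric ttl like the pipeline fixtures below (ttl = 1232).
        token = mocker.MagicMock()
        token.ttl = 1232
        return token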
+ @pytest.mark.it( - "Instantiates with the 'product_info' attribute set to the provided 'product_info' parameter" + "Instantiates with the 'device_id' attribute set to the provided 'device_id' paramater" ) - def test_product_info_set(self): - my_product_info = "some_info" - config = IoTHubPipelineConfig(product_info=my_product_info) + def test_device_id_set(self, sastoken): + config = IoTHubPipelineConfig(device_id=device_id, hostname=hostname, sastoken=sastoken) + assert config.device_id == device_id - assert config.product_info == my_product_info + @pytest.mark.it( + "Instantiates with the 'module_id' attribute set to the provided 'module_id' paramater" + ) + def test_module_id_set(self, sastoken): + config = IoTHubPipelineConfig( + device_id=device_id, module_id=module_id, hostname=hostname, sastoken=sastoken + ) + assert config.module_id == module_id + + @pytest.mark.it( + "Instantiates with the 'module_id' attribute set to 'None' if no 'module_id' paramater is provided" + ) + def test_module_id_default(self, sastoken): + config = IoTHubPipelineConfig(device_id=device_id, hostname=hostname, sastoken=sastoken) + assert config.module_id is None + + @pytest.mark.it( + "Instantiates with the 'product_info' attribute set to the provided 'product_info' parameter" + ) + def test_product_info_set(self, sastoken): + config = IoTHubPipelineConfig( + device_id=device_id, hostname=hostname, product_info=product_info, sastoken=sastoken + ) + assert config.product_info == product_info @pytest.mark.it( - "Instantiates with the 'product_info' attribute defaulting to empty string if there is no provided 'product_info'" + "Instantiates with the 'product_info' attribute defaulting to empty string if no 'product_info' paramater is provided" ) - def test_product_info_default(self): - config = IoTHubPipelineConfig() + def test_product_info_default(self, sastoken): + config = IoTHubPipelineConfig(device_id=device_id, hostname=hostname, sastoken=sastoken) assert config.product_info == "" @pytest.mark.it("Instantiates with the 'blob_upload' attribute set to False") - def test_blob_upload(self): - config = IoTHubPipelineConfig() + def test_blob_upload(self, sastoken): + config = IoTHubPipelineConfig(device_id=device_id, hostname=hostname, sastoken=sastoken) assert config.blob_upload is False @pytest.mark.it("Instantiates with the 'method_invoke' attribute set to False") - def test_method_invoke(self): - config = IoTHubPipelineConfig() + def test_method_invoke(self, sastoken): + config = IoTHubPipelineConfig(device_id=device_id, hostname=hostname, sastoken=sastoken) assert config.method_invoke is False diff --git a/azure-iot-device/tests/iothub/pipeline/test_http_pipeline.py b/azure-iot-device/tests/iothub/pipeline/test_http_pipeline.py index 8109798cc..678bb8645 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_http_pipeline.py +++ b/azure-iot-device/tests/iothub/pipeline/test_http_pipeline.py @@ -20,10 +20,6 @@ pipeline_ops_iothub_http, ) from azure.iot.device.iothub.pipeline import HTTPPipeline, constant -from azure.iot.device.iothub.auth import ( - SymmetricKeyAuthenticationProvider, - X509AuthenticationProvider, -) logging.basicConfig(level=logging.DEBUG) pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") @@ -33,22 +29,18 @@ fake_blob_name = "__fake_blob_name__" -@pytest.fixture -def auth_provider(mocker): - return mocker.MagicMock() - - @pytest.fixture def pipeline_configuration(mocker): mocked_configuration = mocker.MagicMock() mocked_configuration.blob_upload = True 
mocked_configuration.method_invoke = True + mocked_configuration.sastoken.ttl = 1232 # set for compat return mocked_configuration @pytest.fixture -def pipeline(mocker, auth_provider, pipeline_configuration): - pipeline = HTTPPipeline(auth_provider, pipeline_configuration) +def pipeline(mocker, pipeline_configuration): + pipeline = HTTPPipeline(pipeline_configuration) mocker.patch.object(pipeline._pipeline, "run_op") return pipeline @@ -61,7 +53,6 @@ def twin_patch(): # automatically mock the transport for all tests in this file. @pytest.fixture(autouse=True) def mock_transport(mocker): - print("mocking transport") mocker.patch( "azure.iot.device.common.pipeline.pipeline_stages_http.HTTPTransport", autospec=True ) @@ -70,13 +61,13 @@ def mock_transport(mocker): @pytest.mark.describe("HTTPPipeline - Instantiation") class TestHTTPPipelineInstantiation(object): @pytest.mark.it("Configures the pipeline with a series of PipelineStages") - def test_pipeline_configuration(self, auth_provider, pipeline_configuration): - pipeline = HTTPPipeline(auth_provider, pipeline_configuration) + def test_pipeline_configuration(self, pipeline_configuration): + pipeline = HTTPPipeline(pipeline_configuration) curr_stage = pipeline._pipeline expected_stage_order = [ pipeline_stages_base.PipelineRootStage, - pipeline_stages_iothub.UseAuthProviderStage, + pipeline_stages_base.SasTokenRenewalStage, pipeline_stages_iothub_http.IoTHubHTTPTranslationStage, pipeline_stages_http.HTTPTransportStage, ] @@ -90,71 +81,24 @@ def test_pipeline_configuration(self, auth_provider, pipeline_configuration): # Assert there are no more additional stages assert curr_stage is None - # TODO: revist these tests after auth revision - # They are too tied to auth types (and there's too much variance in auths to effectively test) - # Ideally HTTPPipeline is entirely insulated from any auth differential logic (and module/device distinctions) - # In the meantime, we are using a device auth with connection string to stand in for generic SAS auth - # and device auth with X509 certs to stand in for generic X509 auth - @pytest.mark.it( - "Runs a SetAuthProviderOperation with the provided AuthenticationProvider on the pipeline, if using SAS based authentication" - ) - def test_sas_auth(self, mocker, device_connection_string, pipeline_configuration): + @pytest.mark.it("Runs an InitializePipelineOperation on the pipeline") + def test_sas_auth(self, mocker, pipeline_configuration): mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") - auth_provider = SymmetricKeyAuthenticationProvider.parse(device_connection_string) - pipeline = HTTPPipeline(auth_provider, pipeline_configuration) - op = pipeline._pipeline.run_op.call_args[0][1] - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_iothub.SetAuthProviderOperation) - assert op.auth_provider is auth_provider - @pytest.mark.it( - "Raises exceptions that occurred in execution upon unsuccessful completion of the SetAuthProviderOperation" - ) - def test_sas_auth_op_fail( - self, mocker, device_connection_string, arbitrary_exception, pipeline_configuration - ): - old_run_op = pipeline_stages_base.PipelineRootStage._run_op - - def fail_set_auth_provider(self, op): - if isinstance(op, pipeline_ops_iothub.SetAuthProviderOperation): - op.complete(error=arbitrary_exception) - else: - old_run_op(self, op) + pipeline = HTTPPipeline(pipeline_configuration) - mocker.patch.object( - pipeline_stages_base.PipelineRootStage, - "_run_op", - side_effect=fail_set_auth_provider, - 
autospec=True, - ) - - auth_provider = SymmetricKeyAuthenticationProvider.parse(device_connection_string) - with pytest.raises(arbitrary_exception.__class__) as e_info: - HTTPPipeline(auth_provider, pipeline_configuration) - assert e_info.value is arbitrary_exception - - @pytest.mark.it( - "Runs a SetX509AuthProviderOperation with the provided AuthenticationProvider on the pipeline, if using SAS based authentication" - ) - def test_cert_auth(self, mocker, x509, pipeline_configuration): - mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") - auth_provider = X509AuthenticationProvider( - hostname="somehostname", device_id=fake_device_id, x509=x509 - ) - pipeline = HTTPPipeline(auth_provider, pipeline_configuration) op = pipeline._pipeline.run_op.call_args[0][1] assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_iothub.SetX509AuthProviderOperation) - assert op.auth_provider is auth_provider + assert isinstance(op, pipeline_ops_base.InitializePipelineOperation) @pytest.mark.it( - "Raises exceptions that occurred in execution upon unsuccessful completion of the SetX509AuthProviderOperation" + "Raises exceptions that occurred in execution upon unsuccessful completion of the InitializePipelineOperation" ) - def test_cert_auth_op_fail(self, mocker, x509, arbitrary_exception, pipeline_configuration): + def test_sas_auth_op_fail(self, mocker, arbitrary_exception, pipeline_configuration): old_run_op = pipeline_stages_base.PipelineRootStage._run_op - def fail_set_auth_provider(self, op): - if isinstance(op, pipeline_ops_iothub.SetX509AuthProviderOperation): + def fail_initialize(self, op): + if isinstance(op, pipeline_ops_base.InitializePipelineOperation): op.complete(error=arbitrary_exception) else: old_run_op(self, op) @@ -162,15 +106,13 @@ def fail_set_auth_provider(self, op): mocker.patch.object( pipeline_stages_base.PipelineRootStage, "_run_op", - side_effect=fail_set_auth_provider, + side_effect=fail_initialize, autospec=True, ) - auth_provider = X509AuthenticationProvider( - hostname="somehostname", device_id=fake_device_id, x509=x509 - ) - with pytest.raises(arbitrary_exception.__class__): - HTTPPipeline(auth_provider, pipeline_configuration) + with pytest.raises(arbitrary_exception.__class__) as e_info: + HTTPPipeline(pipeline_configuration) + assert e_info.value is arbitrary_exception @pytest.mark.describe("HTTPPipeline - .invoke_method()") diff --git a/azure-iot-device/tests/iothub/pipeline/test_mqtt_pipeline.py b/azure-iot-device/tests/iothub/pipeline/test_mqtt_pipeline.py index a8a7d3635..7cc63e202 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_mqtt_pipeline.py +++ b/azure-iot-device/tests/iothub/pipeline/test_mqtt_pipeline.py @@ -14,6 +14,7 @@ pipeline_ops_base, ) from azure.iot.device.iothub.pipeline import ( + config, pipeline_stages_iothub, pipeline_stages_iothub_mqtt, pipeline_ops_iothub, @@ -21,42 +22,28 @@ ) from azure.iot.device.iothub import Message from azure.iot.device.iothub.pipeline import MQTTPipeline, constant -from azure.iot.device.iothub.auth import ( - SymmetricKeyAuthenticationProvider, - X509AuthenticationProvider, -) +from .conftest import all_features logging.basicConfig(level=logging.DEBUG) pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -# Update this list with features as they are added to the SDK -all_features = [ - constant.C2D_MSG, - constant.INPUT_MSG, - constant.METHODS, - constant.TWIN, - constant.TWIN_PATCHES, -] - - -@pytest.fixture -def auth_provider(mocker): - auth = mocker.MagicMock() - 
# Add values so that it doesn't break down the pipeline. - # This will no longer be needed after auth revisions. - auth.device_id = "fake_device" - auth.module_id = None - return auth - @pytest.fixture def pipeline_configuration(mocker): - return mocker.MagicMock() + # NOTE: Consider parametrizing this to serve as both a device and module configuration. + # The reason this isn't currently done is that it's not strictly necessary, but it might be + # more correct and complete to do so. Certainly this must be done if any device/module + # specific logic is added to the code under test. + mock_config = config.IoTHubPipelineConfig( + device_id="my_device", hostname="my.host.name", sastoken=mocker.MagicMock() + ) + mock_config.sastoken.ttl = 1232 # set for compat + return mock_config @pytest.fixture -def pipeline(mocker, auth_provider, pipeline_configuration): - pipeline = MQTTPipeline(auth_provider, pipeline_configuration) +def pipeline(mocker, pipeline_configuration): + pipeline = MQTTPipeline(pipeline_configuration) mocker.patch.object(pipeline._pipeline, "run_op") return pipeline @@ -69,8 +56,7 @@ def twin_patch(): # automatically mock the transport for all tests in this file. @pytest.fixture(autouse=True) def mock_transport(mocker): - print("mocking transport") - mocker.patch( + return mocker.patch( "azure.iot.device.common.pipeline.pipeline_stages_mqtt.MQTTTransport", autospec=True ) @@ -79,20 +65,20 @@ def mock_transport(mocker): class TestMQTTPipelineInstantiation(object): @pytest.mark.it("Begins tracking the enabled/disabled status of features") @pytest.mark.parametrize("feature", all_features) - def test_features(self, auth_provider, pipeline_configuration, feature): - pipeline = MQTTPipeline(auth_provider, pipeline_configuration) + def test_features(self, pipeline_configuration, feature): + pipeline = MQTTPipeline(pipeline_configuration) pipeline.feature_enabled[feature] # No assertion required - if this doesn't raise a KeyError, it is a success @pytest.mark.it("Marks all features as disabled") - def test_features_disabled(self, auth_provider, pipeline_configuration): - pipeline = MQTTPipeline(auth_provider, pipeline_configuration) + def test_features_disabled(self, pipeline_configuration): + pipeline = MQTTPipeline(pipeline_configuration) for key in pipeline.feature_enabled: assert not pipeline.feature_enabled[key] @pytest.mark.it("Sets all handlers to an initial value of None") - def test_handlers_set_to_none(self, auth_provider, pipeline_configuration): - pipeline = MQTTPipeline(auth_provider, pipeline_configuration) + def test_handlers_set_to_none(self, pipeline_configuration): + pipeline = MQTTPipeline(pipeline_configuration) assert pipeline.on_connected is None assert pipeline.on_disconnected is None assert pipeline.on_c2d_message_received is None @@ -101,20 +87,20 @@ def test_handlers_set_to_none(self, auth_provider, pipeline_configuration): assert pipeline.on_twin_patch_received is None @pytest.mark.it("Configures the pipeline to trigger handlers in response to external events") - def test_handlers_configured(self, auth_provider, pipeline_configuration): - pipeline = MQTTPipeline(auth_provider, pipeline_configuration) + def test_handlers_configured(self, pipeline_configuration): + pipeline = MQTTPipeline(pipeline_configuration) assert pipeline._pipeline.on_pipeline_event_handler is not None assert pipeline._pipeline.on_connected_handler is not None assert pipeline._pipeline.on_disconnected_handler is not None @pytest.mark.it("Configures the pipeline with a series of 
PipelineStages") - def test_pipeline_configuration(self, auth_provider, pipeline_configuration): - pipeline = MQTTPipeline(auth_provider, pipeline_configuration) + def test_pipeline_configuration(self, pipeline_configuration): + pipeline = MQTTPipeline(pipeline_configuration) curr_stage = pipeline._pipeline expected_stage_order = [ pipeline_stages_base.PipelineRootStage, - pipeline_stages_iothub.UseAuthProviderStage, + pipeline_stages_base.SasTokenRenewalStage, pipeline_stages_iothub.EnsureDesiredPropertiesStage, pipeline_stages_iothub.TwinRequestResponseStage, pipeline_stages_base.CoordinateRequestAndResponseStage, @@ -136,71 +122,24 @@ def test_pipeline_configuration(self, auth_provider, pipeline_configuration): # Assert there are no more additional stages assert curr_stage is None - # TODO: revist these tests after auth revision - # They are too tied to auth types (and there's too much variance in auths to effectively test) - # Ideally MQTTPipeline is entirely insulated from any auth differential logic (and module/device distinctions) - # In the meantime, we are using a device auth with connection string to stand in for generic SAS auth - # and device auth with X509 certs to stand in for generic X509 auth - @pytest.mark.it( - "Runs a SetAuthProviderOperation with the provided AuthenticationProvider on the pipeline, if using SAS based authentication" - ) - def test_sas_auth(self, mocker, device_connection_string, pipeline_configuration): + @pytest.mark.it("Runs an InitializePipelineOperation on the pipeline") + def test_init_pipeline(self, mocker, pipeline_configuration): mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") - auth_provider = SymmetricKeyAuthenticationProvider.parse(device_connection_string) - pipeline = MQTTPipeline(auth_provider, pipeline_configuration) - op = pipeline._pipeline.run_op.call_args[0][1] - assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_iothub.SetAuthProviderOperation) - assert op.auth_provider is auth_provider - - @pytest.mark.it( - "Raises exceptions that occurred in execution upon unsuccessful completion of the SetAuthProviderOperation" - ) - def test_sas_auth_op_fail( - self, mocker, device_connection_string, arbitrary_exception, pipeline_configuration - ): - old_run_op = pipeline_stages_base.PipelineRootStage._run_op - - def fail_set_auth_provider(self, op): - if isinstance(op, pipeline_ops_iothub.SetAuthProviderOperation): - op.complete(error=arbitrary_exception) - else: - old_run_op(self, op) - mocker.patch.object( - pipeline_stages_base.PipelineRootStage, - "_run_op", - side_effect=fail_set_auth_provider, - autospec=True, - ) + pipeline = MQTTPipeline(pipeline_configuration) - auth_provider = SymmetricKeyAuthenticationProvider.parse(device_connection_string) - with pytest.raises(arbitrary_exception.__class__) as e_info: - MQTTPipeline(auth_provider, pipeline_configuration) - assert e_info.value is arbitrary_exception - - @pytest.mark.it( - "Runs a SetX509AuthProviderOperation with the provided AuthenticationProvider on the pipeline, if using SAS based authentication" - ) - def test_cert_auth(self, mocker, x509, pipeline_configuration): - mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") - auth_provider = X509AuthenticationProvider( - hostname="somehostname", device_id="somedevice", x509=x509 - ) - pipeline = MQTTPipeline(auth_provider, pipeline_configuration) op = pipeline._pipeline.run_op.call_args[0][1] assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, 
pipeline_ops_iothub.SetX509AuthProviderOperation) - assert op.auth_provider is auth_provider + assert isinstance(op, pipeline_ops_base.InitializePipelineOperation) @pytest.mark.it( - "Raises exceptions that occurred in execution upon unsuccessful completion of the SetX509AuthProviderOperation" + "Raises exceptions that occurred in execution upon unsuccessful completion of the InitializePipelineOperation" ) - def test_cert_auth_op_fail(self, mocker, x509, arbitrary_exception, pipeline_configuration): + def test_init_pipeline_fail(self, mocker, arbitrary_exception, pipeline_configuration): old_run_op = pipeline_stages_base.PipelineRootStage._run_op - def fail_set_auth_provider(self, op): - if isinstance(op, pipeline_ops_iothub.SetX509AuthProviderOperation): + def fail_initialize(self, op): + if isinstance(op, pipeline_ops_base.InitializePipelineOperation): op.complete(error=arbitrary_exception) else: old_run_op(self, op) @@ -208,15 +147,13 @@ def fail_set_auth_provider(self, op): mocker.patch.object( pipeline_stages_base.PipelineRootStage, "_run_op", - side_effect=fail_set_auth_provider, + side_effect=fail_initialize, autospec=True, ) - auth_provider = X509AuthenticationProvider( - hostname="somehostname", device_id="somedevice", x509=x509 - ) - with pytest.raises(arbitrary_exception.__class__): - MQTTPipeline(auth_provider, pipeline_configuration) + with pytest.raises(arbitrary_exception.__class__) as e_info: + MQTTPipeline(pipeline_configuration) + assert e_info.value is arbitrary_exception @pytest.mark.describe("MQTTPipeline - .connect()") @@ -340,31 +277,31 @@ def test_op_fail(self, mocker, pipeline, message, arbitrary_exception): assert cb.call_args == mocker.call(error=arbitrary_exception) -@pytest.mark.describe("MQTTPipeline - .send_output_event()") -class TestMQTTPipelineSendOutputEvent(object): +@pytest.mark.describe("MQTTPipeline - .send_output_message()") +class TestMQTTPipelineSendOutputMessage(object): @pytest.fixture def message(self, message): """Modify message fixture to have an output""" message.output_name = "some output" return message - @pytest.mark.it("Runs a SendOutputEventOperation with the provided Message on the pipeline") + @pytest.mark.it("Runs a SendOutputMessageOperation with the provided Message on the pipeline") def test_runs_op(self, pipeline, message, mocker): - pipeline.send_output_event(message, callback=mocker.MagicMock()) + pipeline.send_output_message(message, callback=mocker.MagicMock()) op = pipeline._pipeline.run_op.call_args[0][0] assert pipeline._pipeline.run_op.call_count == 1 - assert isinstance(op, pipeline_ops_iothub.SendOutputEventOperation) + assert isinstance(op, pipeline_ops_iothub.SendOutputMessageOperation) assert op.message == message @pytest.mark.it( - "Triggers the callback upon successful completion of the SendOutputEventOperation" + "Triggers the callback upon successful completion of the SendOutputMessageOperation" ) def test_op_success_with_callback(self, mocker, pipeline, message): cb = mocker.MagicMock() # Begin operation - pipeline.send_output_event(message, callback=cb) + pipeline.send_output_message(message, callback=cb) assert cb.call_count == 0 # Trigger op completion callback @@ -375,11 +312,11 @@ def test_op_success_with_callback(self, mocker, pipeline, message): assert cb.call_args == mocker.call(error=None) @pytest.mark.it( - "Calls the callback with the error upon unsuccessful completion of the SendOutputEventOperation" + "Calls the callback with the error upon unsuccessful completion of the SendOutputMessageOperation" ) 
def test_op_fail(self, mocker, pipeline, message, arbitrary_exception): cb = mocker.MagicMock() - pipeline.send_output_event(message, callback=cb) + pipeline.send_output_message(message, callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] op.complete(error=arbitrary_exception) diff --git a/azure-iot-device/tests/iothub/pipeline/test_pipeline_ops_iothub.py b/azure-iot-device/tests/iothub/pipeline/test_pipeline_ops_iothub.py index ecb39f6f3..5476ea803 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_pipeline_ops_iothub.py +++ b/azure-iot-device/tests/iothub/pipeline/test_pipeline_ops_iothub.py @@ -14,173 +14,6 @@ pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -class SetAuthProviderOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_iothub.SetAuthProviderOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"auth_provider": mocker.MagicMock(), "callback": mocker.MagicMock()} - return kwargs - - -class SetAuthProviderOperationInstantiationTests(SetAuthProviderOperationTestConfig): - @pytest.mark.it( - "Initializes 'auth_provider' attribute with the provided 'auth_provider' parameter" - ) - def test_auth_provider(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.auth_provider is init_kwargs["auth_provider"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_iothub.SetAuthProviderOperation, - op_test_config_class=SetAuthProviderOperationTestConfig, - extended_op_instantiation_test_class=SetAuthProviderOperationInstantiationTests, -) - - -class SetX509AuthProviderOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_iothub.SetX509AuthProviderOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"auth_provider": mocker.MagicMock(), "callback": mocker.MagicMock()} - return kwargs - - -class SetX509AuthProviderOperationInstantiationTests(SetX509AuthProviderOperationTestConfig): - @pytest.mark.it( - "Initializes 'auth_provider' attribute with the provided 'auth_provider' parameter" - ) - def test_auth_provider(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.auth_provider is init_kwargs["auth_provider"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_iothub.SetX509AuthProviderOperation, - op_test_config_class=SetX509AuthProviderOperationTestConfig, - extended_op_instantiation_test_class=SetX509AuthProviderOperationInstantiationTests, -) - - -class SetIoTHubConnectionArgsOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_iothub.SetIoTHubConnectionArgsOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = { - "device_id": "some_device_id", - "hostname": "some_hostname", - "callback": mocker.MagicMock(), - "module_id": "some_module_id", - "gateway_hostname": "some_gateway_hostname", - "server_verification_cert": "some_server_verification_cert", - "client_cert": "some_client_cert", - "sas_token": "some_sas_token", - } - return kwargs - - -class SetIoTHubConnectionArgsOperationInstantiationTests( - SetIoTHubConnectionArgsOperationTestConfig -): - @pytest.mark.it("Initializes 'device_id' attribute with the provided 'device_id' parameter") - def test_device_id(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.device_id == init_kwargs["device_id"] - - @pytest.mark.it("Initializes 'hostname' attribute with the provided 
'hostname' parameter") - def test_hostname(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.hostname == init_kwargs["hostname"] - - @pytest.mark.it("Initializes 'module_id' attribute with the provided 'module_id' parameter") - def test_module_id(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.module_id == init_kwargs["module_id"] - - @pytest.mark.it( - "Initializes 'module_id' attribute to None if no 'module_id' parameter is provided" - ) - def test_module_id_default(self, cls_type, init_kwargs): - del init_kwargs["module_id"] - op = cls_type(**init_kwargs) - assert op.module_id is None - - @pytest.mark.it( - "Initializes 'gateway_hostname' attribute with the provided 'gateway_hostname' parameter" - ) - def test_gateway_hostname(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.gateway_hostname == init_kwargs["gateway_hostname"] - - @pytest.mark.it( - "Initializes 'gateway_hostname' attribute to None if no 'gateway_hostname' parameter is provided" - ) - def test_gateway_hostname_default(self, cls_type, init_kwargs): - del init_kwargs["gateway_hostname"] - op = cls_type(**init_kwargs) - assert op.gateway_hostname is None - - @pytest.mark.it( - "Initializes 'server_verification_cert' attribute with the provided 'server_verification_cert' parameter" - ) - def test_server_verification_cert(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.server_verification_cert == init_kwargs["server_verification_cert"] - - @pytest.mark.it( - "Initializes 'server_verification_cert' attribute to None if no 'server_verification_cert' parameter is provided" - ) - def test_server_verification_cert_default(self, cls_type, init_kwargs): - del init_kwargs["server_verification_cert"] - op = cls_type(**init_kwargs) - assert op.server_verification_cert is None - - @pytest.mark.it("Initializes 'client_cert' attribute with the provided 'client_cert' parameter") - def test_client_cert(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.client_cert == init_kwargs["client_cert"] - - @pytest.mark.it( - "Initializes 'client_cert' attribute to None if no 'client_cert' parameter is provided" - ) - def test_client_cert_default(self, cls_type, init_kwargs): - del init_kwargs["client_cert"] - op = cls_type(**init_kwargs) - assert op.client_cert is None - - @pytest.mark.it("Initializes 'sas_token' attribute with the provided 'sas_token' parameter") - def test_sas_token(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.sas_token == init_kwargs["sas_token"] - - @pytest.mark.it( - "Initializes 'sas_token' attribute to None if no 'sas_token' parameter is provided" - ) - def test_sas_token_default(self, cls_type, init_kwargs): - del init_kwargs["sas_token"] - op = cls_type(**init_kwargs) - assert op.sas_token is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_iothub.SetIoTHubConnectionArgsOperation, - op_test_config_class=SetIoTHubConnectionArgsOperationTestConfig, - extended_op_instantiation_test_class=SetIoTHubConnectionArgsOperationInstantiationTests, -) - - class SendD2CMessageOperationTestConfig(object): @pytest.fixture def cls_type(self): @@ -207,10 +40,10 @@ def test_message(self, cls_type, init_kwargs): ) -class SendOutputEventOperationTestConfig(object): +class SendOutputMessageOperationTestConfig(object): @pytest.fixture def cls_type(self): - return pipeline_ops_iothub.SendOutputEventOperation + return 
pipeline_ops_iothub.SendOutputMessageOperation @pytest.fixture def init_kwargs(self, mocker): @@ -218,7 +51,7 @@ def init_kwargs(self, mocker): return kwargs -class SendOutputEventOperationInstantiationTests(SendOutputEventOperationTestConfig): +class SendOutputMessageOperationInstantiationTests(SendOutputMessageOperationTestConfig): @pytest.mark.it("Initializes 'message' attribute with the provided 'message' parameter") def test_message(self, cls_type, init_kwargs): op = cls_type(**init_kwargs) @@ -227,9 +60,9 @@ def test_message(self, cls_type, init_kwargs): pipeline_ops_test.add_operation_tests( test_module=this_module, - op_class_under_test=pipeline_ops_iothub.SendOutputEventOperation, - op_test_config_class=SendOutputEventOperationTestConfig, - extended_op_instantiation_test_class=SendOutputEventOperationInstantiationTests, + op_class_under_test=pipeline_ops_iothub.SendOutputMessageOperation, + op_test_config_class=SendOutputMessageOperationTestConfig, + extended_op_instantiation_test_class=SendOutputMessageOperationInstantiationTests, ) diff --git a/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub.py b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub.py index 483d2e748..03edbc5f1 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub.py +++ b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub.py @@ -3,28 +3,22 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -import functools import json import logging import pytest import sys -import threading -from concurrent.futures import Future from azure.iot.device.exceptions import ServiceError from azure.iot.device.common import handle_exceptions -from azure.iot.device.common.pipeline import pipeline_events_base, pipeline_ops_base from azure.iot.device.iothub.pipeline import ( pipeline_events_iothub, pipeline_ops_iothub, pipeline_stages_iothub, constant as pipeline_constants, ) +from azure.iot.device.common.pipeline import pipeline_events_base, pipeline_ops_base from azure.iot.device.iothub.pipeline.exceptions import PipelineError -from azure.iot.device.iothub.auth.authentication_provider import AuthenticationProvider from tests.common.pipeline.helpers import StageRunOpTestBase, StageHandlePipelineEventTestBase from tests.common.pipeline import pipeline_stage_test -from azure.iot.device.common.models.x509 import X509 -from azure.iot.device.iothub.auth.x509_authentication_provider import X509AuthenticationProvider logging.basicConfig(level=logging.DEBUG) this_module = sys.modules[__name__] @@ -62,327 +56,6 @@ def mock_handle_background_exception(mocker): return mock_handler -########################### -# USE AUTH PROVIDER STAGE # -########################### - - -class UseAuthProviderStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_iothub.UseAuthProviderStage - - @pytest.fixture - def init_kwargs(self): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - return stage - - -class UseAuthProviderStageInstantiationTests(UseAuthProviderStageTestConfig): - @pytest.mark.it("Initializes 'auth_provider' as None") - def test_auth_provider(self, init_kwargs): - stage = pipeline_stages_iothub.UseAuthProviderStage(**init_kwargs) - assert stage.auth_provider is None 
- - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_iothub.UseAuthProviderStage, - stage_test_config_class=UseAuthProviderStageTestConfig, - extended_stage_instantiation_test_class=UseAuthProviderStageInstantiationTests, -) - - -@pytest.mark.describe( - "UseAuthProviderStage - .run_op() -- Called with SetAuthProviderOperation (SAS Authentication)" -) -class TestUseAuthProviderStageRunOpWithSetAuthProviderOperation( - StageRunOpTestBase, UseAuthProviderStageTestConfig -): - # Auth Providers are configured with different values depending on if the higher level client - # is a Device or Module. Parametrize with both possibilities. - # TODO: Eventually would be ideal to test using real auth provider instead of the fake one - # This probably should just wait until auth provider refactor for ease though. - @pytest.fixture(params=["Device", "Module"]) - def fake_auth_provider(self, request, mocker): - class FakeAuthProvider(AuthenticationProvider): - pass - - if request.param == "Device": - fake_auth_provider = FakeAuthProvider(hostname=fake_hostname, device_id=fake_device_id) - else: - fake_auth_provider = FakeAuthProvider( - hostname=fake_hostname, device_id=fake_device_id, module_id=fake_module_id - ) - fake_auth_provider.get_current_sas_token = mocker.MagicMock() - fake_auth_provider.on_sas_token_updated_handler_list = [mocker.MagicMock()] - return fake_auth_provider - - @pytest.fixture - def op(self, mocker, fake_auth_provider): - return pipeline_ops_iothub.SetAuthProviderOperation( - auth_provider=fake_auth_provider, callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Sets the operation's authentication provider on the stage as the 'auth_provider' attribute" - ) - def test_set_auth_provider(self, op, stage): - assert stage.auth_provider is None - - stage.run_op(op) - - assert stage.auth_provider is op.auth_provider - - # NOTE: Because currently auth providers don't have a consistent attribute surface, only some - # have the 'server_verification_cert' and 'gateway_hostname' attributes, so parametrize to show they default to - # None when non-existent. If authentication providers ever receive a uniform surface, this - # parametrization will no longer be required. 
- @pytest.mark.it( - "Sends a new SetIoTHubConnectionArgsOperation op down the pipeline, containing connection info from the authentication provider" - ) - @pytest.mark.parametrize( - "all_auth_args", [True, False], ids=["All authentication args", "Only guaranteed args"] - ) - def test_send_new_op_down(self, mocker, op, stage, all_auth_args): - if all_auth_args: - op.auth_provider.server_verification_cert = fake_server_verification_cert - op.auth_provider.gateway_hostname = fake_gateway_hostname - - stage.run_op(op) - - # A SetIoTHubConnectionArgsOperation op has been sent down the pipeline - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation) - - # The IoTHubConnectionArgsOperation has details from the auth provider - assert new_op.device_id == op.auth_provider.device_id - assert new_op.module_id == op.auth_provider.module_id - assert new_op.hostname == op.auth_provider.hostname - assert new_op.sas_token is op.auth_provider.get_current_sas_token.return_value - assert new_op.client_cert is None - if all_auth_args: - assert new_op.server_verification_cert == op.auth_provider.server_verification_cert - assert new_op.gateway_hostname == op.auth_provider.gateway_hostname - else: - assert new_op.server_verification_cert is None - assert new_op.gateway_hostname is None - - @pytest.mark.it( - "Completes the original operation upon completion of the SetIoTHubConnectionArgsOperation" - ) - def test_complete_worker(self, op, stage, op_error): - # Run original op - stage.run_op(op) - assert not op.completed - - # A SetIoTHubConnectionArgsOperation op has been sent down the pipeline - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation) - assert not new_op.completed - - # Complete the new op - new_op.complete(error=op_error) - - # Both ops are now completed - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - -@pytest.mark.describe( - "UseAuthProviderStage - .run_op() -- Called with SetX509AuthProviderOperation (X509 Authentication)" -) -class TestUseAuthProviderStageRunOpWithSetX509AuthProviderOperation( - StageRunOpTestBase, UseAuthProviderStageTestConfig -): - # Auth Providers are configured with different values depending on if the higher level client - # is a Device or Module. Parametrize with both possibilities. - # TODO: Eventually would be ideal to test using real auth provider instead of the fake one - # This probably should just wait until auth provider refactor for ease though. 
- @pytest.fixture(params=["Device", "Module"]) - def fake_auth_provider(self, request, mocker): - class FakeAuthProvider(AuthenticationProvider): - pass - - if request.param == "Device": - fake_auth_provider = FakeAuthProvider(hostname=fake_hostname, device_id=fake_device_id) - else: - fake_auth_provider = FakeAuthProvider( - hostname=fake_hostname, device_id=fake_device_id, module_id=fake_module_id - ) - fake_auth_provider.get_x509_certificate = mocker.MagicMock() - return fake_auth_provider - - @pytest.fixture - def op(self, mocker, fake_auth_provider): - return pipeline_ops_iothub.SetX509AuthProviderOperation( - auth_provider=fake_auth_provider, callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Sets the operation's authentication provider on the stage as the 'auth_provider' attribute" - ) - def test_set_auth_provider(self, op, stage): - assert stage.auth_provider is None - - stage.run_op(op) - - assert stage.auth_provider is op.auth_provider - - # NOTE: Because currently auth providers don't have a consistent attribute surface, only some - # have the 'server_verification_cert' and 'gateway_hostname' attributes, so parametrize to show they default to - # None when non-existent. If authentication providers ever receive a uniform surface, this - # parametrization will no longer be required. - @pytest.mark.it( - "Sends a new SetIoTHubConnectionArgsOperation op down the pipeline, containing connection info from the authentication provider" - ) - @pytest.mark.parametrize( - "all_auth_args", [True, False], ids=["All authentication args", "Only guaranteed args"] - ) - def test_send_new_op_down(self, mocker, op, stage, all_auth_args): - if all_auth_args: - op.auth_provider.server_verification_cert = fake_server_verification_cert - op.auth_provider.gateway_hostname = fake_gateway_hostname - - stage.run_op(op) - - # A SetIoTHubConnectionArgsOperation op has been sent down the pipeline - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation) - - # The IoTHubConnectionArgsOperation has details from the auth provider - assert new_op.device_id == op.auth_provider.device_id - assert new_op.module_id == op.auth_provider.module_id - assert new_op.hostname == op.auth_provider.hostname - assert new_op.client_cert is op.auth_provider.get_x509_certificate.return_value - assert new_op.sas_token is None - if all_auth_args: - assert new_op.server_verification_cert == op.auth_provider.server_verification_cert - assert new_op.gateway_hostname == op.auth_provider.gateway_hostname - else: - assert new_op.server_verification_cert is None - assert new_op.gateway_hostname is None - - @pytest.mark.it( - "Completes the original operation upon completion of the SetIoTHubConnectionArgsOperation" - ) - def test_complete_worker(self, op, stage, op_error): - # Run original op - stage.run_op(op) - assert not op.completed - - # A SetIoTHubConnectionArgsOperation op has been sent down the pipeline - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_iothub.SetIoTHubConnectionArgsOperation) - assert not new_op.completed - - # Complete the new op - new_op.complete(error=op_error) - - # Both ops are now completed - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - -@pytest.mark.describe("UseAuthProviderStage - .run_op() -- Called with arbitrary other operation") -class 
TestUseAuthProviderStageRunOpWithAribitraryOperation( - StageRunOpTestBase, UseAuthProviderStageTestConfig -): - @pytest.fixture - def op(self, arbitrary_op): - return arbitrary_op - - @pytest.mark.it("Sends the operation down the pipeline") - def test_sends_down(self, mocker, stage, op): - stage.run_op(op) - - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(op) - assert not op.completed - - -@pytest.mark.describe( - "UseAuthProviderStage - OCCURANCE: SAS Authentication Provider updates SAS token" -) -class TestUseAuthProviderStageWhenAuthProviderGeneratesNewSasToken(UseAuthProviderStageTestConfig): - # Auth Providers are configured with different values depending on if the higher level client - # is a Device or Module. Parametrize with both possibilities. - # TODO: Eventually would be ideal to test using real auth provider instead of the fake one - # This probably should just wait until auth provider refactor for ease though. - @pytest.fixture(params=["Device", "Module"]) - def fake_auth_provider(self, request, mocker): - class FakeAuthProvider(AuthenticationProvider): - pass - - if request.param == "Device": - fake_auth_provider = FakeAuthProvider(hostname=fake_hostname, device_id=fake_device_id) - else: - fake_auth_provider = FakeAuthProvider( - hostname=fake_hostname, device_id=fake_device_id, module_id=fake_module_id - ) - fake_auth_provider.get_current_sas_token = mocker.MagicMock() - fake_auth_provider.on_sas_token_updated_handler_list = [mocker.MagicMock()] - return fake_auth_provider - - @pytest.fixture - def stage(self, mocker, init_kwargs, fake_auth_provider): - stage = pipeline_stages_iothub.UseAuthProviderStage(**init_kwargs) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - - # Attach an auth provider - set_auth_op = pipeline_ops_iothub.SetAuthProviderOperation( - auth_provider=fake_auth_provider, callback=mocker.MagicMock() - ) - stage.run_op(set_auth_op) - assert stage.auth_provider is fake_auth_provider - stage.send_op_down.reset_mock() - stage.send_event_up.reset_mock() - return stage - - @pytest.mark.it("Sends an UpdateSasTokenOperation with the new SAS token down the pipeline") - def test_generates_new_token(self, mocker, stage): - for x in stage.auth_provider.on_sas_token_updated_handler_list: - x() - - assert stage.send_op_down.call_count == 1 - op = stage.send_op_down.call_args[0][0] - assert isinstance(op, pipeline_ops_base.UpdateSasTokenOperation) - assert op.sas_token is stage.auth_provider.get_current_sas_token.return_value - - @pytest.mark.it( - "Sends the error to the background exception handler, if the UpdateSasTokenOperation is completed with error" - ) - def test_update_fails( - self, mocker, stage, arbitrary_exception, mock_handle_background_exception - ): - for x in stage.auth_provider.on_sas_token_updated_handler_list: - x() - - assert stage.send_op_down.call_count == 1 - op = stage.send_op_down.call_args[0][0] - - assert mock_handle_background_exception.call_count == 0 - - op.complete(error=arbitrary_exception) - assert mock_handle_background_exception.call_count == 1 - assert mock_handle_background_exception.call_args == mocker.call(arbitrary_exception) - - ######################################### # ENSURE DESIRED PROPERTIES STAGE STAGE # ######################################### @@ -799,10 +472,16 @@ def test_request_and_response_op(self, mocker, stage, op): class TestTwinRequestResponseStageRunOpWithPatchTwinReportedPropertiesOperation( StageRunOpTestBase, 
TwinRequestResponseStageTestConfig ): - # CT-TODO: parametrize this with realistic json objects - @pytest.fixture - def json_patch(self): - return {"json_key": "json_val"} + @pytest.fixture(params=["Dictionary Patch", "String Patch", "Integer Patch", "None Patch"]) + def json_patch(self, request): + if request.param == "Dictionary Patch": + return {"json_key": "json_val"} + elif request.param == "String Patch": + return "some_json" + elif request.param == "Integer Patch": + return 1234 + elif request.param == "None Patch": + return None @pytest.fixture def op(self, mocker, json_patch): diff --git a/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_http.py b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_http.py index 5f8b9e41e..cece80929 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_http.py +++ b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_http.py @@ -19,7 +19,7 @@ from tests.common.pipeline.helpers import StageRunOpTestBase from tests.common.pipeline import pipeline_stage_test from azure.iot.device import constant as pkg_constant -from azure.iot.device.product_info import ProductInfo +from azure.iot.device import user_agent logging.basicConfig(level=logging.DEBUG) pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") @@ -61,163 +61,50 @@ def init_kwargs(self): return {} @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): + def pipeline_config(self, mocker): + # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. + # Manually override this fixture to configure it for modules where needed + cfg = config.IoTHubPipelineConfig( + hostname="http://my.hostname", device_id="my_device", sastoken=mocker.MagicMock() + ) + return cfg + + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs, pipeline_config): stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage( + pipeline_config + ) stage.send_op_down = mocker.MagicMock() stage.send_event_up = mocker.MagicMock() return stage -class IoTHubHTTPTranslationStageInstantiationTests(IoTHubHTTPTranslationStageTestConfig): - @pytest.mark.it("Initializes 'device_id' as None") - def test_device_id(self, init_kwargs): - stage = pipeline_stages_iothub_http.IoTHubHTTPTranslationStage(**init_kwargs) - assert stage.device_id is None - - @pytest.mark.it("Initializes 'module_id' as None") - def test_module_id(self, init_kwargs): - stage = pipeline_stages_iothub_http.IoTHubHTTPTranslationStage(**init_kwargs) - assert stage.module_id is None - - @pytest.mark.it("Initializes 'hostname' as None") - def test_hostname(self, init_kwargs): - stage = pipeline_stages_iothub_http.IoTHubHTTPTranslationStage(**init_kwargs) - assert stage.hostname is None - - pipeline_stage_test.add_base_pipeline_stage_tests( test_module=this_module, stage_class_under_test=pipeline_stages_iothub_http.IoTHubHTTPTranslationStage, stage_test_config_class=IoTHubHTTPTranslationStageTestConfig, - extended_stage_instantiation_test_class=IoTHubHTTPTranslationStageInstantiationTests, ) -@pytest.mark.describe( - "IoTHubHTTPTranslationStage - .run_op() -- Called with SetIoTHubConnectionArgsOperation op" -) -class TestIoTHubHTTPTranslationStageRunOpCalledWithConnectionArgsOperation( - IoTHubHTTPTranslationStageTestConfig, StageRunOpTestBase -): - @pytest.fixture(params=["SAS", "X509"]) - def auth_type(self, request): - return request.param - - @pytest.fixture(params=[True, False], ids=["w/ GatewayHostName", "No
GatewayHostName"]) - def use_gateway_hostname(self, request): - return request.param - - @pytest.fixture( - params=[True, False], ids=["w/ server verification cert", "No server verification cert"] - ) - def use_server_verification_cert(self, request): - return request.param - - @pytest.fixture(params=["Device", "Module"]) - def op(self, mocker, request, auth_type, use_gateway_hostname, use_server_verification_cert): - kwargs = { - "device_id": "fake_device_id", - "hostname": "fake_hostname", - "callback": mocker.MagicMock(), - } - if request.param == "Module": - kwargs["module_id"] = "fake_module_id" - - if auth_type == "SAS": - kwargs["sas_token"] = "fake_sas_token" - else: - kwargs["client_cert"] = mocker.MagicMock() # representing X509 obj - - if use_gateway_hostname: - kwargs["gateway_hostname"] = "fake_gateway_hostname" - - if use_server_verification_cert: - kwargs["server_verification_cert"] = "fake_server_verification_cert" - - return pipeline_ops_iothub.SetIoTHubConnectionArgsOperation(**kwargs) - - @pytest.mark.it( - "Sets the 'device_id' and 'module_id' values from the op as the stage's 'device_id' and 'module_id' attributes" - ) - def test_cache_device_id_and_module_id(self, stage, op): - assert stage.device_id is None - assert stage.module_id is None - - stage.run_op(op) - - assert stage.device_id == op.device_id - assert stage.module_id == op.module_id - - @pytest.mark.it( - "Sets the 'gateway_hostname' value from the op as the stage's 'hostname' attribute if one is provided, otherwise, use the op's 'hostname'" - ) - def test_cache_hostname(self, stage, op): - assert stage.hostname is None - stage.run_op(op) - - if op.gateway_hostname is not None: - assert stage.hostname == op.gateway_hostname - assert stage.hostname != op.hostname - else: - assert stage.hostname == op.hostname - assert stage.hostname != op.gateway_hostname - - @pytest.mark.it( - "Sends a new SetHTTPConnectionArgsOperation op down the pipeline, configured based on the settings of the SetIoTHubConnectionArgsOperation" - ) - def test_sends_op_down(self, mocker, stage, op): - stage.run_op(op) - - # Op was sent down - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance(new_op, pipeline_ops_http.SetHTTPConnectionArgsOperation) - - # Validate contents of the op - assert new_op.hostname == stage.hostname - assert new_op.server_verification_cert == op.server_verification_cert - assert new_op.client_cert == op.client_cert - assert new_op.sas_token == op.sas_token - - @pytest.mark.it( - "Completes the original SetIoTHubConnectionArgsOperation (with the same error, or lack thereof) if the new SetHTTPConnectionArgsOperation is completed later on" - ) - def test_completing_new_op_completes_original(self, mocker, stage, op_error, op): - stage.run_op(op) - assert stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - - assert not op.completed - assert not new_op.completed - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - @pytest.mark.describe( "IoTHubHTTPTranslationStage - .run_op() -- Called with MethodInvokeOperation op" ) class TestIoTHubHTTPTranslationStageRunOpCalledWithMethodInvokeOperation( IoTHubHTTPTranslationStageTestConfig, StageRunOpTestBase ): - # Because Storage/Blob related functionality is limited to Module, configure the stage for a module @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): - stage = 
cls_type(**init_kwargs) - pl_config = config.IoTHubPipelineConfig() - stage.pipeline_root = pipeline_stages_base.PipelineRootStage( - pipeline_configuration=pl_config + def pipeline_config(self, mocker): + # Because Method related functionality is limited to Module, configure the stage for a module + # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. + cfg = config.IoTHubPipelineConfig( + hostname="http://my.hostname", + gateway_hostname="http://my.gateway.hostname", + device_id="my_device", + module_id="my_module", + sastoken=mocker.MagicMock(), ) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - stage.device_id = "fake_device_id" - stage.module_id = "fake_module_id" - stage.hostname = "fake_hostname" - return stage + return cfg @pytest.fixture(params=["Targeting Device Method", "Targeting Module Method"]) def op(self, mocker, request): @@ -282,7 +169,7 @@ def test_sends_get_storage_request(self, mocker, stage, op, mock_http_path_iothu pytest.param(12345, id="Non-string custom user agent"), ], ) - def test_new_op_headers(self, mocker, stage, op, custom_user_agent): + def test_new_op_headers(self, mocker, stage, op, custom_user_agent, pipeline_config): stage.pipeline_root.pipeline_configuration.product_info = custom_user_agent stage.run_op(op) @@ -293,11 +180,11 @@ def test_new_op_headers(self, mocker, stage, op, custom_user_agent): # Validate headers expected_user_agent = urllib.parse.quote_plus( - ProductInfo.get_iothub_user_agent() + str(custom_user_agent) + user_agent.get_iothub_user_agent() + str(custom_user_agent) ) - expected_edge_string = "{}/{}".format(stage.device_id, stage.module_id) + expected_edge_string = "{}/{}".format(pipeline_config.device_id, pipeline_config.module_id) - assert new_op.headers["Host"] == stage.hostname + assert new_op.headers["Host"] == pipeline_config.gateway_hostname assert new_op.headers["Content-Type"] == "application/json" assert new_op.headers["Content-Length"] == len(new_op.body) assert new_op.headers["x-ms-edge-moduleId"] == expected_edge_string @@ -450,21 +337,14 @@ def test_new_op_completes_with_error(self, mocker, stage, op, arbitrary_exceptio class TestIoTHubHTTPTranslationStageRunOpCalledWithGetStorageInfoOperation( IoTHubHTTPTranslationStageTestConfig, StageRunOpTestBase ): - - # Because Storage/Blob related functionality is limited to Devices, configure the stage for a device @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - pl_config = config.IoTHubPipelineConfig() - stage.pipeline_root = pipeline_stages_base.PipelineRootStage( - pipeline_configuration=pl_config + def pipeline_config(self, mocker): + # Because Storage/Blob related functionality is limited to Device, configure pipeline for a device + # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. 
+ cfg = config.IoTHubPipelineConfig( + hostname="http://my.hostname", device_id="my_device", sastoken=mocker.MagicMock() ) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - stage.device_id = "fake_device_id" - stage.module_id = None - stage.hostname = "fake_hostname" - return stage + return cfg @pytest.fixture def op(self, mocker): @@ -484,7 +364,9 @@ def test_sends_op_down(self, mocker, stage, op): @pytest.mark.it( "Configures the HTTPRequestAndResponseOperation with request details for sending a Get Storage Info request" ) - def test_sends_get_storage_request(self, mocker, stage, op, mock_http_path_iothub): + def test_sends_get_storage_request( + self, mocker, stage, op, mock_http_path_iothub, pipeline_config + ): stage.run_op(op) # Op was sent down @@ -495,7 +377,7 @@ def test_sends_get_storage_request(self, mocker, stage, op, mock_http_path_iothu # Validate request assert mock_http_path_iothub.get_storage_info_for_blob_path.call_count == 1 assert mock_http_path_iothub.get_storage_info_for_blob_path.call_args == mocker.call( - stage.device_id + pipeline_config.device_id ) expected_path = mock_http_path_iothub.get_storage_info_for_blob_path.return_value @@ -517,7 +399,7 @@ def test_sends_get_storage_request(self, mocker, stage, op, mock_http_path_iothu pytest.param(12345, id="Non-string custom user agent"), ], ) - def test_new_op_headers(self, mocker, stage, op, custom_user_agent): + def test_new_op_headers(self, mocker, stage, op, custom_user_agent, pipeline_config): stage.pipeline_root.pipeline_configuration.product_info = custom_user_agent stage.run_op(op) @@ -528,10 +410,10 @@ def test_new_op_headers(self, mocker, stage, op, custom_user_agent): # Validate headers expected_user_agent = urllib.parse.quote_plus( - ProductInfo.get_iothub_user_agent() + str(custom_user_agent) + user_agent.get_iothub_user_agent() + str(custom_user_agent) ) - assert new_op.headers["Host"] == stage.hostname + assert new_op.headers["Host"] == pipeline_config.hostname assert new_op.headers["Accept"] == "application/json" assert new_op.headers["Content-Type"] == "application/json" assert new_op.headers["Content-Length"] == len(new_op.body) @@ -681,21 +563,14 @@ def test_new_op_completes_with_error(self, mocker, stage, op, arbitrary_exceptio class TestIoTHubHTTPTranslationStageRunOpCalledWithNotifyBlobUploadStatusOperation( IoTHubHTTPTranslationStageTestConfig, StageRunOpTestBase ): - - # Because Storage/Blob related functionality is limited to Devices, configure the stage for a device @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - pl_config = config.IoTHubPipelineConfig() - stage.pipeline_root = pipeline_stages_base.PipelineRootStage( - pipeline_configuration=pl_config + def pipeline_config(self, mocker): + # Because Storage/Blob related functionality is limited to Device, configure pipeline for a device + # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. 
+ cfg = config.IoTHubPipelineConfig( + hostname="http://my.hostname", device_id="my_device", sastoken=mocker.MagicMock() ) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - stage.device_id = "fake_device_id" - stage.module_id = None - stage.hostname = "fake_hostname" - return stage + return cfg @pytest.fixture def op(self, mocker): @@ -719,7 +594,9 @@ def test_sends_op_down(self, mocker, stage, op): @pytest.mark.it( "Configures the HTTPRequestAndResponseOperation with request details for sending a Notify Blob Upload Status request" ) - def test_sends_get_storage_request(self, mocker, stage, op, mock_http_path_iothub): + def test_sends_get_storage_request( + self, mocker, stage, op, mock_http_path_iothub, pipeline_config + ): stage.run_op(op) # Op was sent down @@ -730,7 +607,7 @@ def test_sends_get_storage_request(self, mocker, stage, op, mock_http_path_iothu # Validate request assert mock_http_path_iothub.get_notify_blob_upload_status_path.call_count == 1 assert mock_http_path_iothub.get_notify_blob_upload_status_path.call_args == mocker.call( - stage.device_id + pipeline_config.device_id ) expected_path = mock_http_path_iothub.get_notify_blob_upload_status_path.return_value @@ -752,7 +629,7 @@ def test_sends_get_storage_request(self, mocker, stage, op, mock_http_path_iothu pytest.param(12345, id="Non-string custom user agent"), ], ) - def test_new_op_headers(self, mocker, stage, op, custom_user_agent): + def test_new_op_headers(self, mocker, stage, op, custom_user_agent, pipeline_config): stage.pipeline_root.pipeline_configuration.product_info = custom_user_agent stage.run_op(op) @@ -763,10 +640,10 @@ def test_new_op_headers(self, mocker, stage, op, custom_user_agent): # Validate headers expected_user_agent = urllib.parse.quote_plus( - ProductInfo.get_iothub_user_agent() + str(custom_user_agent) + user_agent.get_iothub_user_agent() + str(custom_user_agent) ) - assert new_op.headers["Host"] == stage.hostname + assert new_op.headers["Host"] == pipeline_config.hostname assert new_op.headers["Content-Type"] == "application/json; charset=utf-8" assert new_op.headers["Content-Length"] == len(new_op.body) assert new_op.headers["User-Agent"] == expected_user_agent diff --git a/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_mqtt.py b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_mqtt.py index 4bb578a51..1367ef1bd 100644 --- a/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_mqtt.py +++ b/azure-iot-device/tests/iothub/pipeline/test_pipeline_stages_iothub_mqtt.py @@ -21,1497 +21,938 @@ pipeline_ops_iothub, pipeline_stages_iothub_mqtt, config, + mqtt_topic_iothub, ) from azure.iot.device.iothub.pipeline.exceptions import OperationError, PipelineError from azure.iot.device.iothub.models.message import Message from azure.iot.device.iothub.models.methods import MethodRequest, MethodResponse -from tests.common.pipeline.helpers import all_common_ops, all_common_events, StageTestBase -from tests.iothub.pipeline.helpers import all_iothub_ops, all_iothub_events +from tests.common.pipeline.helpers import StageRunOpTestBase, StageHandlePipelineEventTestBase from tests.common.pipeline import pipeline_stage_test -from azure.iot.device import constant as pkg_constant -from azure.iot.device.product_info import ProductInfo +from azure.iot.device import constant as pkg_constant, user_agent logging.basicConfig(level=logging.DEBUG) - this_module = sys.modules[__name__] - - -# This fixture makes it look like all test in 
this file tests are running -# inside the pipeline thread. Because this is an autouse fixture, we -# manually add it to the individual test.py files that need it. If, -# instead, we had added it to some conftest.py, it would be applied to -# every tests in every file and we don't want that. -@pytest.fixture(autouse=True) -def apply_fake_pipeline_thread(fake_pipeline_thread): - pass - - -fake_device_id = "__fake_device_id__" -fake_module_id = "__fake_module_id__" -fake_hostname = "__fake_hostname__" -fake_gateway_hostname = "__fake_gateway_hostname__" -fake_server_verification_cert = "__fake_server_verification_cert__" -fake_client_cert = "__fake_client_cert__" -fake_sas_token = "__fake_sas_token__" - -fake_message_id = "ee9e738b-4f47-447a-9892-5b1d1d7ca5" -fake_message_id_encoded = "%24.mid=ee9e738b-4f47-447a-9892-5b1d1d7ca5" -fake_message_body = "__fake_message_body__" -fake_output_name = "__fake_output_name__" -fake_output_name_encoded = "%24.on=__fake_output_name__" -fake_content_type = "text/json" -fake_content_type_encoded = "%24.ct=text%2Fjson" -fake_content_encoding = "utf-16" -fake_content_encoding_encoded = "%24.ce=utf-16" -default_content_type = "application/json" -default_content_type_encoded = "%24.ct=application%2Fjson" -default_content_encoding_encoded = "%24.ce=utf-8" -fake_message = Message(fake_message_body) -security_message_interface_id_encoded = "%24.ifid=urn%3Aazureiot%3ASecurity%3ASecurityAgent%3A1" -fake_request_id = "__fake_request_id__" -fake_method_name = "__fake_method_name__" -fake_method_payload = "__fake_method_payload__" -fake_method_status = "__fake_method_status__" -fake_method_response = MethodResponse( - request_id=fake_request_id, status=fake_method_status, payload=fake_method_payload -) - -invalid_feature_name = "__invalid_feature_name__" -unmatched_mqtt_topic = "__unmatched_mqtt_topic__" -fake_mqtt_payload = "__fake_mqtt_payload__" - -fake_c2d_topic = "devices/{}/messages/devicebound/".format(fake_device_id) -fake_c2d_topic_with_content_type = "{}{}".format(fake_c2d_topic, fake_content_type_encoded) -fake_c2d_topic_for_another_device = "devices/__other_device__/messages/devicebound/" - -fake_input_name = "__fake_input_name__" -fake_input_message_topic = "devices/{}/modules/{}/inputs/{}/".format( - fake_device_id, fake_module_id, fake_input_name -) -fake_input_message_topic_with_content_type = "{}{}".format( - fake_input_message_topic, fake_content_type_encoded -) -fake_input_message_topic_for_another_module = "devices/{}/modules/__other_module__/messages/devicebound/".format( - fake_device_id -) -fake_input_message_topic_for_another_device = "devices/__other_device__/modules/{}/messages/devicebound/".format( - fake_module_id -) - -fake_method_request_topic = "$iothub/methods/POST/{}/?$rid={}".format( - fake_method_name, fake_request_id -) -fake_method_request_payload = "{}".encode("utf-8") - -encoded_user_agent = urllib.parse.quote(ProductInfo.get_iothub_user_agent(), safe="") - -fake_message_user_property_1_key = "is-muggle" -fake_message_user_property_1_value = "yes" -fake_message_user_property_2_key = "sorted-house" -fake_message_user_property_2_value = "hufflepuff" -fake_message_user_property_1_encoded = "is-muggle=yes" -fake_message_user_property_2_encoded = "sorted-house=hufflepuff" - -ops_handled_by_this_stage = [ - pipeline_ops_iothub.SetIoTHubConnectionArgsOperation, - pipeline_ops_iothub.SendD2CMessageOperation, - pipeline_ops_base.UpdateSasTokenOperation, - pipeline_ops_iothub.SendOutputEventOperation, - 
pipeline_ops_iothub.SendMethodResponseOperation, - pipeline_ops_base.RequestOperation, - pipeline_ops_base.EnableFeatureOperation, - pipeline_ops_base.DisableFeatureOperation, -] - -events_handled_by_this_stage = [pipeline_events_mqtt.IncomingMQTTMessageEvent] - -pipeline_stage_test.add_base_pipeline_stage_tests_old( - cls=pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage, - module=this_module, - all_ops=all_common_ops + all_iothub_ops, - handled_ops=ops_handled_by_this_stage, - all_events=all_common_events + all_iothub_events, - handled_events=events_handled_by_this_stage, - extra_initializer_defaults={"feature_to_topic": dict}, -) - - -def create_message_with_user_properties(message_content, is_multiple): - m = Message(message_content) - m.custom_properties[fake_message_user_property_1_key] = fake_message_user_property_1_value - if is_multiple: - m.custom_properties[fake_message_user_property_2_key] = fake_message_user_property_2_value - return m - - -def create_security_message(message_content): - msg = Message(message_content) - msg.set_as_security_message() - return msg - - -def create_message_with_system_and_user_properties(message_content, is_multiple): - if is_multiple: - msg = Message(message_content, message_id=fake_message_id, output_name=fake_output_name) - else: - msg = Message(message_content, message_id=fake_message_id) - - msg.custom_properties[fake_message_user_property_1_key] = fake_message_user_property_1_value - if is_multiple: - msg.custom_properties[fake_message_user_property_2_key] = fake_message_user_property_2_value - return msg - - -def create_security_message_with_system_and_user_properties(message_content, is_multiple): - if is_multiple: - msg = Message(message_content, message_id=fake_message_id, output_name=fake_output_name) - else: - msg = Message(message_content, message_id=fake_message_id) - - msg.custom_properties[fake_message_user_property_1_key] = fake_message_user_property_1_value - if is_multiple: - msg.custom_properties[fake_message_user_property_2_key] = fake_message_user_property_2_value - msg.set_as_security_message() - return msg - - -def create_message_for_output_with_user_properties(message_content, is_multiple): - m = Message(message_content, output_name=fake_output_name) - m.custom_properties[fake_message_user_property_1_key] = fake_message_user_property_1_value - if is_multiple: - m.custom_properties[fake_message_user_property_2_key] = fake_message_user_property_2_value - return m - - -def create_message_for_output_with_system_and_user_properties(message_content, is_multiple): - if is_multiple: - msg = Message( - message_content, - output_name=fake_output_name, - message_id=fake_message_id, - content_type=fake_content_type, - ) - else: - msg = Message(message_content, output_name=fake_output_name, message_id=fake_message_id) - - msg.custom_properties[fake_message_user_property_1_key] = fake_message_user_property_1_value - if is_multiple: - msg.custom_properties[fake_message_user_property_2_key] = fake_message_user_property_2_value - return msg +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread", "mock_mqtt_topic") @pytest.fixture -def set_connection_args(mocker): - return pipeline_ops_iothub.SetIoTHubConnectionArgsOperation( - device_id=fake_device_id, hostname=fake_hostname, callback=mocker.MagicMock() - ) +def mock_mqtt_topic(mocker): + # Don't mock the whole module, just mock what we want to (which is most of it). 
+ # Mocking out the get_x_topic style functions is useful, but the ones that + # match patterns and return bools (is_x_topic) make testing annoying if mocked. + mocker.patch.object(mqtt_topic_iothub, "get_telemetry_topic_for_publish") + mocker.patch.object(mqtt_topic_iothub, "get_method_topic_for_publish") + mocker.patch.object(mqtt_topic_iothub, "get_twin_topic_for_publish") + mocker.patch.object(mqtt_topic_iothub, "get_c2d_topic_for_subscribe") + mocker.patch.object(mqtt_topic_iothub, "get_input_topic_for_subscribe") + mocker.patch.object(mqtt_topic_iothub, "get_method_topic_for_subscribe") + mocker.patch.object(mqtt_topic_iothub, "get_twin_response_topic_for_subscribe") + mocker.patch.object(mqtt_topic_iothub, "get_twin_patch_topic_for_subscribe") + mocker.patch.object(mqtt_topic_iothub, "encode_message_properties_in_topic") + mocker.patch.object(mqtt_topic_iothub, "extract_message_properties_from_topic") + # It's kind of weird that we return the (unmocked) module, but it's easier this way, + # and since it's a module, not a function, we'd never treat it like a mock anyway + # (you don't check the call count of a module) + return mqtt_topic_iothub + + +@pytest.fixture(params=[True, False], ids=["With error", "No error"]) +def op_error(request, arbitrary_exception): + if request.param: + return arbitrary_exception + else: + return None +# NOTE: This fixture is defined out here rather than on a class because it is used for both +# EnableFeatureOperation and DisableFeatureOperation tests @pytest.fixture -def set_connection_args_for_device(set_connection_args): - return set_connection_args +def expected_mqtt_topic_fn(mock_mqtt_topic, iothub_pipeline_feature): + if iothub_pipeline_feature == constant.C2D_MSG: + return mock_mqtt_topic.get_c2d_topic_for_subscribe + elif iothub_pipeline_feature == constant.INPUT_MSG: + return mock_mqtt_topic.get_input_topic_for_subscribe + elif iothub_pipeline_feature == constant.METHODS: + return mock_mqtt_topic.get_method_topic_for_subscribe + elif iothub_pipeline_feature == constant.TWIN: + return mock_mqtt_topic.get_twin_response_topic_for_subscribe + elif iothub_pipeline_feature == constant.TWIN_PATCHES: + return mock_mqtt_topic.get_twin_patch_topic_for_subscribe + else: + # This shouldn't happen + assert False +# NOTE: This fixture is defined out here rather than on a class because it is used for both +# EnableFeatureOperation and DisableFeatureOperation tests @pytest.fixture -def set_connection_args_for_module(set_connection_args): - set_connection_args.module_id = fake_module_id - return set_connection_args - - -class IoTHubMQTTTranslationStageTestBase(StageTestBase): - @pytest.fixture(autouse=True) - def stage_base_configuration(self, stage, mocker): - class NextStageForTest(pipeline_stages_base.PipelineStage): - def _run_op(self, op): - pass - - next = NextStageForTest() - root = ( - pipeline_stages_base.PipelineRootStage(config.IoTHubPipelineConfig()) - .append_stage(stage) - .append_stage(next) +def expected_mqtt_topic_fn_call(mocker, iothub_pipeline_feature, stage): + if iothub_pipeline_feature == constant.C2D_MSG: + return mocker.call(stage.pipeline_root.pipeline_configuration.device_id) + elif iothub_pipeline_feature == constant.INPUT_MSG: + return mocker.call( + stage.pipeline_root.pipeline_configuration.device_id, + stage.pipeline_root.pipeline_configuration.module_id, ) + else: + return mocker.call() - mocker.spy(stage, "_run_op") - mocker.spy(stage, "run_op") - - mocker.spy(next, "_run_op") - mocker.spy(next, "run_op") - return root +class
IoTHubMQTTTranslationStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage @pytest.fixture - def stage(self, mocker): - stage = pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage() - mocker.spy(stage, "send_op_down") - return stage + def init_kwargs(self): + return {} @pytest.fixture - def stage_configured_for_device( - self, stage, stage_base_configuration, set_connection_args_for_device, mocker - ): - set_connection_args_for_device.callback = None - stage.run_op(set_connection_args_for_device) - mocker.resetall() + def pipeline_config(self, mocker): + # NOTE 1: auth type shouldn't matter for this stage, so just give it a fake sastoken for now. + # NOTE 2: This config is configured for a device, not a module. Where relevant, override this + # fixture or dynamically add a module_id + cfg = config.IoTHubPipelineConfig( + hostname="http://my.hostname", device_id="my_device", sastoken=mocker.MagicMock() + ) + return cfg @pytest.fixture - def stage_configured_for_module( - self, stage, stage_base_configuration, set_connection_args_for_module, mocker - ): - set_connection_args_for_module.callback = None - stage.run_op(set_connection_args_for_module) - mocker.resetall() + def stage(self, mocker, cls_type, init_kwargs, pipeline_config): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage(pipeline_config) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage - @pytest.fixture(params=["device", "module"]) - def stages_configured_for_both( - self, request, stage, stage_base_configuration, set_connection_args, mocker - ): - set_connection_args.callback = None - if request.param == "module": - set_connection_args.module_id = fake_module_id - stage.run_op(set_connection_args) - mocker.resetall() + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_iothub_mqtt.IoTHubMQTTTranslationStage, + stage_test_config_class=IoTHubMQTTTranslationStageTestConfig, +) @pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- called with SetIoTHubConnectionArgsOperation" + "IoTHubMQTTTranslationStage - .run_op() -- Called with InitializePipelineOperation (Pipeline has Device Configuration)" ) -class TestIoTHubMQTTConverterWithSetAuthProviderArgs(IoTHubMQTTTranslationStageTestBase): - @pytest.mark.it( - "Runs a pipeline_ops_mqtt.SetMQTTConnectionArgsOperation worker operation on the next stage" +class TestIoTHubMQTTTranslationStageRunOpWithInitializePipelineOperationOnDevice( + StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) + + @pytest.mark.it("Derives the MQTT client id, and sets it on the op") + def test_client_id(self, stage, op, pipeline_config): + assert not hasattr(op, "client_id") + stage.run_op(op) + + assert op.client_id == pipeline_config.device_id + + @pytest.mark.it("Derives the MQTT username, and sets it on the op") + @pytest.mark.parametrize( + "cust_product_info", + [ + pytest.param("", id="No custom product info"), + pytest.param("my-product-info", id="With custom product info"), + pytest.param("my$product$info", id="With custom product info (URL encoding required)"), + ], ) - def test_runs_set_connection_args(self, mocker, stage, set_connection_args): - set_connection_args.spawn_worker_op = mocker.MagicMock() - 
stage.run_op(set_connection_args) - assert set_connection_args.spawn_worker_op.call_count == 1 - assert ( - set_connection_args.spawn_worker_op.call_args[1]["worker_op_type"] - is pipeline_ops_mqtt.SetMQTTConnectionArgsOperation + def test_username(self, stage, op, pipeline_config, cust_product_info): + pipeline_config.product_info = cust_product_info + assert not hasattr(op, "username") + stage.run_op(op) + + expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}{custom_product_info}".format( + hostname=pipeline_config.hostname, + client_id=pipeline_config.device_id, + api_version=pkg_constant.IOTHUB_API_VERSION, + user_agent=urllib.parse.quote(user_agent.get_iothub_user_agent(), safe=""), + custom_product_info=urllib.parse.quote(pipeline_config.product_info, safe=""), ) - worker = set_connection_args.spawn_worker_op.return_value - assert stage.send_op_down.call_count == 1 - assert stage.send_op_down.call_args == mocker.call(worker) + assert op.username == expected_username @pytest.mark.it( - "Sets connection_args.client_id to auth_provider_args.device_id if auth_provider_args.module_id is None" - ) - def test_sets_client_id_for_devices(self, stage, set_connection_args): - stage.run_op(set_connection_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.client_id == fake_device_id + "ALWAYS uses the pipeline configuration's hostname in the MQTT username and NEVER the gateway_hostname" + ) + def test_hostname_vs_gateway_hostname(self, stage, op, pipeline_config): + # NOTE: this is a sanity check test. There's no reason it should ever be using + # gateway hostname rather than hostname, but these are easily confused fields, so + # this test has been included to catch any possible errors down the road + pipeline_config.hostname = "http://my.hostname" + pipeline_config.gateway_hostname = "http://my.gateway.hostname" + stage.run_op(op) - @pytest.mark.it( - "Sets connection_args.client_id to auth_provider_args.device_id/auth_provider_args.module_id if auth_provider_args.module_id is not None" - ) - def test_sets_client_id_for_modules(self, stage, set_connection_args_for_module): - stage.run_op(set_connection_args_for_module) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.client_id == "{}/{}".format(fake_device_id, fake_module_id) + assert pipeline_config.hostname in op.username + assert pipeline_config.gateway_hostname not in op.username - @pytest.mark.it( - "Sets connection_args.hostname to auth_provider.hostname if auth_provider.gateway_hostname is None" - ) - def test_sets_hostname_if_no_gateway(self, stage, set_connection_args): - stage.run_op(set_connection_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.hostname == fake_hostname + @pytest.mark.it("Sends the op down the pipeline") + def test_sends_down(self, mocker, stage, op): + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) - @pytest.mark.it( - "Sets connection_args.hostname to auth_provider.gateway_hostname if auth_provider.gateway_hostname is not None" - ) - def test_sets_hostname_if_yes_gateway(self, stage, set_connection_args): - set_connection_args.gateway_hostname = fake_gateway_hostname - stage.run_op(set_connection_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.hostname == fake_gateway_hostname - @pytest.mark.it( - "Sets connection_args.username to 
auth_provider.hostname/auth_provider/device_id/?api-version={api_version}&DeviceClientType={user_agent} if auth_provider_args.gateway_hostname is None and module_id is None" - ) - def test_sets_device_username_if_no_gateway(self, stage, set_connection_args): - stage.run_op(set_connection_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.username == "{}/{}/?api-version={}&DeviceClientType={}".format( - fake_hostname, fake_device_id, pkg_constant.IOTHUB_API_VERSION, encoded_user_agent +@pytest.mark.describe( + "IoTHubMQTTTranslationStage - .run_op() -- Called with InitializePipelineOperation (Pipeline has Module Configuration)" +) +class TestIoTHubMQTTTranslationStageRunOpWithInitializePipelineOperationOnModule( + StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig +): + @pytest.fixture + def pipeline_config(self, mocker): + # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. + cfg = config.IoTHubPipelineConfig( + hostname="http://my.hostname", + device_id="my_device", + module_id="my_module", + sastoken=mocker.MagicMock(), ) + return cfg - @pytest.mark.it( - "Sets connection_args.username to auth_provider.hostname/device_id/?api-version={api_version}&DeviceClientType={user_agent} if auth_provider_args.gateway_hostname is not None and module_id is None" - ) - def test_sets_device_username_if_yes_gateway(self, stage, set_connection_args): - set_connection_args.gateway_hostname = fake_gateway_hostname - stage.run_op(set_connection_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.username == "{}/{}/?api-version={}&DeviceClientType={}".format( - fake_hostname, fake_device_id, pkg_constant.IOTHUB_API_VERSION, encoded_user_agent - ) + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - @pytest.mark.it( - "Sets connection_args.username to auth_provider.hostname/auth_provider/device_id/?api-version={api_version}&DeviceClientType={user_agent} if auth_provider_args.gateway_hostname is None and module_id is None" - ) - def test_sets_module_username_if_no_gateway(self, stage, set_connection_args_for_module): - stage.run_op(set_connection_args_for_module) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.username == "{}/{}/{}/?api-version={}&DeviceClientType={}".format( - fake_hostname, - fake_device_id, - fake_module_id, - pkg_constant.IOTHUB_API_VERSION, - encoded_user_agent, - ) + @pytest.mark.it("Derives the MQTT client id, and sets it on the op") + def test_client_id(self, stage, op, pipeline_config): + stage.run_op(op) - @pytest.mark.it( - "Sets connection_args.username to auth_provider.hostname/device_id/module_id/?api-version={api_version}&DeviceClientType={user_agent} if auth_provider_args.gateway_hostname is not None and module_id is None" - ) - def test_sets_module_username_if_yes_gateway(self, stage, set_connection_args_for_module): - set_connection_args_for_module.gateway_hostname = fake_gateway_hostname - stage.run_op(set_connection_args_for_module) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.username == "{}/{}/{}/?api-version={}&DeviceClientType={}".format( - fake_hostname, - fake_device_id, - fake_module_id, - pkg_constant.IOTHUB_API_VERSION, - encoded_user_agent, + expected_client_id = "{device_id}/{module_id}".format( + device_id=pipeline_config.device_id, module_id=pipeline_config.module_id ) + assert op.client_id == expected_client_id - @pytest.mark.it( - "Appends product_info to connection_args.username to 
if self.pipeline_root.pipeline_configuration.product_info is not None" - ) + @pytest.mark.it("Derives the MQTT username, and sets it on the op") @pytest.mark.parametrize( - "fake_product_info, expected_product_info", + "cust_product_info", [ - ("", ""), - ("__fake:product:info__", "__fake%3Aproduct%3Ainfo__"), - (4, 4), - ( - ["fee,fi,fo,fum"], - "%5B%27fee%2Cfi%2Cfo%2Cfum%27%5D", - ), # URI Encoding for str version of list - ( - {"fake_key": "fake_value"}, - "%7B%27fake_key%27%3A%20%27fake_value%27%7D", - ), # URI Encoding for str version of dict + pytest.param("", id="No custom product info"), + pytest.param("my-product-info", id="With custom product info"), + pytest.param("my$product$info", id="With custom product info (URL encoding required)"), ], ) - def test_appends_product_info_to_device_username( - self, stage, set_connection_args, fake_product_info, expected_product_info - ): - set_connection_args.gateway_hostname = fake_gateway_hostname - stage.pipeline_root.pipeline_configuration.product_info = fake_product_info - stage.run_op(set_connection_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.username == "{}/{}/?api-version={}&DeviceClientType={}{}".format( - fake_hostname, - fake_device_id, - pkg_constant.IOTHUB_API_VERSION, - encoded_user_agent, - expected_product_info, + def test_username(self, stage, op, pipeline_config, cust_product_info): + pipeline_config.product_info = cust_product_info + stage.run_op(op) + + expected_client_id = "{device_id}/{module_id}".format( + device_id=pipeline_config.device_id, module_id=pipeline_config.module_id + ) + expected_username = "{hostname}/{client_id}/?api-version={api_version}&DeviceClientType={user_agent}{custom_product_info}".format( + hostname=pipeline_config.hostname, + client_id=expected_client_id, + api_version=pkg_constant.IOTHUB_API_VERSION, + user_agent=urllib.parse.quote(user_agent.get_iothub_user_agent(), safe=""), + custom_product_info=urllib.parse.quote(pipeline_config.product_info, safe=""), ) + assert op.username == expected_username @pytest.mark.it( - "Sets connection_args.server_verification_cert to auth_provider.server_verification_cert" - ) - def test_sets_server_verification_cert(self, stage, set_connection_args): - set_connection_args.server_verification_cert = fake_server_verification_cert - stage.run_op(set_connection_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.server_verification_cert == fake_server_verification_cert - - @pytest.mark.it("Sets connection_args.client_cert to auth_provider.client_cert") - def test_sets_client_cert(self, stage, set_connection_args): - set_connection_args.client_cert = fake_client_cert - stage.run_op(set_connection_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.client_cert == fake_client_cert - - @pytest.mark.it("Sets connection_args.sas_token to auth_provider.sas_token.") - def test_sets_sas_token(self, stage, set_connection_args): - set_connection_args.sas_token = fake_sas_token - stage.run_op(set_connection_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.sas_token == fake_sas_token + "ALWAYS uses the pipeline configuration's hostname in the MQTT username and NEVER the gateway_hostname" + ) + def test_hostname_vs_gateway_hostname(self, stage, op, pipeline_config): + # NOTE: this is a sanity check test. 
There's no reason it should ever be using + # gateway hostname rather than hostname, but these are easily confused fields, so + # this test has been included to catch any possible errors down the road + pipeline_config.hostname = "http://my.hostname" + pipeline_config.gateway_hostname = "http://my.gateway.hostname" + stage.run_op(op) + assert pipeline_config.hostname in op.username + assert pipeline_config.gateway_hostname not in op.username -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- called with UpdateSasTokenOperation if the transport is disconnected" -) -class TestIoTHubMQTTConverterWithUpdateSasTokenOperationDisconnected( - IoTHubMQTTTranslationStageTestBase -): - @pytest.fixture - def op(self, mocker): - return pipeline_ops_base.UpdateSasTokenOperation( - sas_token=fake_sas_token, callback=mocker.MagicMock() - ) + @pytest.mark.it("Sends the op down the pipeline") + def test_sends_down(self, mocker, stage, op): + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) - @pytest.fixture(autouse=True) - def transport_is_disconnected(self, stage): - stage.pipeline_root.connected = False - @pytest.mark.it("Immediately passes the operation to the next stage") - def test_passes_op_immediately(self, stage, op): - stage.run_op(op) - assert stage.next.run_op.call_count == 1 - assert stage.next.run_op.call_args[0][0] == op +# NOTE: All of the following run op tests are tested against a pipeline_config that has been +# configured for a Device Client, not a Module Client. It's worth considering parametrizing +# that fixture so that these tests all run twice - once for a Device, and once for a Module. +# HOWEVER, it's not strictly necessary, due to knowledge of implementation - we are testing that +# the expected values (including module id, which just happens to be set to None when configured +# for a device) are passed where they are expected to be passed. If they're being passed +# correctly, we know it would work no matter what the values are set to. +# +# This also avoids us having module specific tests for device-only features, and vice versa. +# +# In conclusion, while the pipeline_config fixture is technically configured for a device, +# all of the .run_op() tests are written as if it's completely generic. Perhaps this will +# need to change later on.
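As a minimal sketch (assuming pytest's fixture params feature and the config.IoTHubPipelineConfig constructor used in the fixtures above; the param ids and module_id value are illustrative), the pipeline_config fixture mentioned in the NOTE could be parametrized so each .run_op() test runs once for a Device and once for a Module:

    # Hypothetical parametrized fixture - a Device client is simply a config with no module_id
    @pytest.fixture(params=["Device", "Module"])
    def pipeline_config(self, mocker, request):
        module_id = "my_module" if request.param == "Module" else None
        return config.IoTHubPipelineConfig(
            hostname="http://my.hostname",
            device_id="my_device",
            module_id=module_id,
            sastoken=mocker.MagicMock(),
        )

A setup like this would double the run count of every test using the fixture, which is why the NOTE treats it as worth considering rather than strictly necessary.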
@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- called with UpdateSasTokenOperation if the transport is connected" + "IoTHubMQTTTranslationStage - .run_op() -- Called with SendD2CMessageOperation" ) -class TestIoTHubMQTTConverterWithUpdateSasTokenOperationConnected( - IoTHubMQTTTranslationStageTestBase +class TestIoTHubMQTTTranslationStageRunOpWithSendD2CMessageOperation( + StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig ): @pytest.fixture def op(self, mocker): - op = pipeline_ops_base.UpdateSasTokenOperation( - sas_token=fake_sas_token, callback=mocker.MagicMock() + return pipeline_ops_iothub.SendD2CMessageOperation( + message=Message("my message"), callback=mocker.MagicMock() ) - mocker.spy(op, "complete") - return op - - @pytest.fixture(autouse=True) - def transport_is_connected(self, stage): - stage.pipeline_root.connected = True - - @pytest.mark.it("Immediately passes the operation to the next stage") - def test_passes_op_immediately(self, stage, op): - stage.run_op(op) - assert stage.next.run_op.call_count == 1 - assert stage.next.run_op.call_args[0][0] == op @pytest.mark.it( - "Passes down a ReauthorizeConnectionOperation instead of completing the op with success after the lower level stage returns success for the UpdateSasTokenOperation" + "Derives the IoTHub telemetry topic from the device/module details, and encodes the op's message's properties in the resulting topic string" ) - def test_passes_down_reauthorize_connection(self, stage, op, mocker): - def run_op(op): - print("in run_op {}".format(op.__class__.__name__)) - if isinstance(op, pipeline_ops_base.UpdateSasTokenOperation): - op.complete(error=None) - else: - pass - - stage.next.run_op = mocker.MagicMock(side_effect=run_op) + def test_telemetry_topic(self, mocker, stage, op, pipeline_config, mock_mqtt_topic): + # Although this requirement refers to message properties, we don't actually have to + # parametrize the op to have them, because the entire logic of encoding message properties + # is handled by the mocked out mqtt_topic_iothub library, so whether or not our fixture + # has message properties on the message or not is irrelevant. stage.run_op(op) - assert stage.next.run_op.call_count == 2 - assert stage.next.run_op.call_args_list[0][0][0] == op - assert isinstance( - stage.next.run_op.call_args_list[1][0][0], - pipeline_ops_base.ReauthorizeConnectionOperation, + assert mock_mqtt_topic.get_telemetry_topic_for_publish.call_count == 1 + assert mock_mqtt_topic.get_telemetry_topic_for_publish.call_args == mocker.call( + device_id=pipeline_config.device_id, module_id=pipeline_config.module_id + ) + assert mock_mqtt_topic.encode_message_properties_in_topic.call_count == 1 + assert mock_mqtt_topic.encode_message_properties_in_topic.call_args == mocker.call( + op.message, mock_mqtt_topic.get_telemetry_topic_for_publish.return_value ) - # CT-TODO: Make this test clearer - this below assertion is a bit confusing - # What is happening here is that the run_op defined above for the mock only completes - # ops of type UpdateSasTokenOperation (i.e. variable 'op'). However, completing the - # op triggers a callback which halts the completion, and then spawn a reauthorize_connection worker op, - # which must be completed before full completion of 'op' can occur. However, as the above - # run_op mock only completes ops of type UpdateSasTokenOperation, this never happens, - # thus op is not completed. - assert not op.completed - # CT-TODO: remove this once able. 
This test does not have a high degree of accuracy, and its contents - # could be tested better once stage tests are restructured. This test is overlapping with tests of - # worker op functionality, that should not be being tested at this granularity here. @pytest.mark.it( - "Completes the op with success if some lower level stage returns success for the ReauthorizeConnectionOperation" + "Sends a new MQTTPublishOperation down the pipeline with the message data from the original op and the derived topic string" ) - def test_reauthorize_connection_succeeds(self, mocker, stage, next_stage_succeeds, op): - # default is for stage.next.run_op to return success for all ops + def test_sends_mqtt_publish_op_down(self, mocker, stage, op, mock_mqtt_topic): stage.run_op(op) - assert stage.next.run_op.call_count == 2 - assert stage.next.run_op.call_args_list[0][0][0] == op - assert isinstance( - stage.next.run_op.call_args_list[1][0][0], - pipeline_ops_base.ReauthorizeConnectionOperation, - ) + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) + assert new_op.topic == mock_mqtt_topic.encode_message_properties_in_topic.return_value + assert new_op.payload == op.message.data + + @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") + def test_complete_resulting_op(self, stage, op, op_error): + stage.run_op(op) + assert not op.completed + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + + new_op.complete(error=op_error) + + assert new_op.completed + assert new_op.error is op_error assert op.completed - assert op.complete.call_count == 2 # op was completed twice due to an uncompletion + assert op.error is op_error - # most recent call, i.e. 
one triggered by the successful reauthorize_connection - assert op.complete.call_args == mocker.call(error=None) - # CT-TODO: As above, remove/restructure ASAP +@pytest.mark.describe( + "IoTHubMQTTTranslationStage - .run_op() -- Called with SendOutputMessageOperation" +) +class TestIoTHubMQTTTranslationStageRunOpWithSendOutputMessageOperation( + StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_iothub.SendOutputMessageOperation( + message=Message("my message"), callback=mocker.MagicMock() + ) + @pytest.mark.it( - "Completes the op with failure if some lower level stage returns failure for the ReauthorizeConnectionOperation" + "Derives the IoTHub telemetry topic using the device/module details, and encodes the op's message's properties in the resulting topic string" ) - def test_reauthorize_connection_fails(self, stage, op, mocker, arbitrary_exception): - cb = op.callback_stack[0] - - def run_op(op): - print("in run_op {}".format(op.__class__.__name__)) - if isinstance(op, pipeline_ops_base.UpdateSasTokenOperation): - op.complete(error=None) - elif isinstance(op, pipeline_ops_base.ReauthorizeConnectionOperation): - op.complete(error=arbitrary_exception) - else: - pass - - stage.next.run_op = mocker.MagicMock(side_effect=run_op) + def test_telemetry_topic(self, mocker, stage, op, pipeline_config, mock_mqtt_topic): + # Although this requirement refers to message properties, we don't actually have to + # parametrize the op to have them, because the entire logic of encoding message properties + # is handled by the mocked out mqtt_topic_iothub library, so whether or not our fixture + # has message properties on the message or not is irrelevant. stage.run_op(op) - assert stage.next.run_op.call_count == 2 - assert stage.next.run_op.call_args_list[0][0][0] == op - assert isinstance( - stage.next.run_op.call_args_list[1][0][0], - pipeline_ops_base.ReauthorizeConnectionOperation, + assert mock_mqtt_topic.get_telemetry_topic_for_publish.call_count == 1 + assert mock_mqtt_topic.get_telemetry_topic_for_publish.call_args == mocker.call( + device_id=pipeline_config.device_id, module_id=pipeline_config.module_id + ) + assert mock_mqtt_topic.encode_message_properties_in_topic.call_count == 1 + assert mock_mqtt_topic.encode_message_properties_in_topic.call_args == mocker.call( + op.message, mock_mqtt_topic.get_telemetry_topic_for_publish.return_value ) - assert cb.call_count == 1 - assert cb.call_args == mocker.call(op=op, error=arbitrary_exception) - - -basic_ops = [ - { - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": {"message": fake_message, "callback": None}, - "new_op_class": pipeline_ops_mqtt.MQTTPublishOperation, - }, - { - "op_class": pipeline_ops_iothub.SendOutputEventOperation, - "op_init_kwargs": {"message": fake_message, "callback": None}, - "new_op_class": pipeline_ops_mqtt.MQTTPublishOperation, - }, - { - "op_class": pipeline_ops_iothub.SendMethodResponseOperation, - "op_init_kwargs": {"method_response": fake_method_response, "callback": None}, - "new_op_class": pipeline_ops_mqtt.MQTTPublishOperation, - }, - { - "op_class": pipeline_ops_base.EnableFeatureOperation, - "op_init_kwargs": {"feature_name": constant.C2D_MSG, "callback": None}, - "new_op_class": pipeline_ops_mqtt.MQTTSubscribeOperation, - }, - { - "op_class": pipeline_ops_base.DisableFeatureOperation, - "op_init_kwargs": {"feature_name": constant.C2D_MSG, "callback": None}, - "new_op_class": 
pipeline_ops_mqtt.MQTTUnsubscribeOperation, - }, -] - - -# CT-TODO: simplify this -@pytest.mark.parametrize( - "params", - basic_ops, - ids=["{}->{}".format(x["op_class"].__name__, x["new_op_class"].__name__) for x in basic_ops], -) -@pytest.mark.describe("IoTHubMQTTTranslationStage - .run_op() -- called with basic MQTT operations") -class TestIoTHubMQTTConverterBasicOperations(IoTHubMQTTTranslationStageTestBase): - @pytest.fixture - def op(self, params, mocker): - op = params["op_class"](**params["op_init_kwargs"]) - mocker.spy(op, "spawn_worker_op") - return op - @pytest.mark.it("Runs a worker operation on the next stage") - def test_spawn_worker_op(self, params, stage, stages_configured_for_both, op): + @pytest.mark.it( + "Sends a new MQTTPublishOperation down the pipeline with the message data from the original op and the derived topic string" + ) + def test_sends_mqtt_publish_op_down(self, mocker, stage, op, mock_mqtt_topic): stage.run_op(op) - assert op.spawn_worker_op.call_count == 1 - assert op.spawn_worker_op.call_args[1]["worker_op_type"] is params["new_op_class"] - new_op = stage.next._run_op.call_args[0][0] - assert isinstance(new_op, params["new_op_class"]) - - -publish_ops = [ - { - "name": "send telemetry", - "stage_type": "device", - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": {"message": Message(fake_message_body), "callback": None}, - "topic": "devices/{}/messages/events/".format(fake_device_id), - "publish_payload": fake_message_body, - }, - { - "name": "send telemetry with content type and content encoding", - "stage_type": "device", - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": { - "message": Message( - fake_message_body, - content_type=fake_content_type, - content_encoding=fake_content_encoding, - ), - "callback": None, - }, - "topic": "devices/{}/messages/events/{}&{}".format( - fake_device_id, fake_content_type_encoded, fake_content_encoding_encoded - ), - "publish_payload": fake_message_body, - }, - { - "name": "send telemetry overriding only the content type", - "stage_type": "device", - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": { - "message": Message(fake_message_body, content_type=fake_content_type), - "callback": None, - }, - "topic": "devices/{}/messages/events/{}".format(fake_device_id, fake_content_type_encoded), - "publish_payload": fake_message_body, - }, - { - "name": "send telemetry with single system property", - "stage_type": "device", - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": { - "message": Message(fake_message_body, output_name=fake_output_name), - "callback": None, - }, - "topic": "devices/{}/messages/events/{}".format(fake_device_id, fake_output_name_encoded), - "publish_payload": fake_message_body, - }, - { - "name": "send security message", - "stage_type": "device", - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": {"message": create_security_message(fake_message_body), "callback": None}, - "topic": "devices/{}/messages/events/{}".format( - fake_device_id, security_message_interface_id_encoded - ), - "publish_payload": fake_message_body, - }, - { - "name": "send telemetry with multiple system properties", - "stage_type": "device", - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": { - "message": Message( - fake_message_body, message_id=fake_message_id, output_name=fake_output_name - ), - "callback": None, - }, - "topic": 
"devices/{}/messages/events/{}&{}".format( - fake_device_id, fake_output_name_encoded, fake_message_id_encoded - ), - "publish_payload": fake_message_body, - }, - { - "name": "send telemetry with only single user property", - "stage_type": "device", - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": { - "message": create_message_with_user_properties(fake_message_body, is_multiple=False), - "callback": None, - }, - "topic": "devices/{}/messages/events/{}".format( - fake_device_id, fake_message_user_property_1_encoded - ), - "publish_payload": fake_message_body, - }, - { - "name": "send telemetry with only multiple user properties", - "stage_type": "device", - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": { - "message": create_message_with_user_properties(fake_message_body, is_multiple=True), - "callback": None, - }, - # For more than 1 user property the order could be different, creating 2 different topics - "topic1": "devices/{}/messages/events/{}&{}".format( - fake_device_id, - fake_message_user_property_1_encoded, - fake_message_user_property_2_encoded, - ), - "topic2": "devices/{}/messages/events/{}&{}".format( - fake_device_id, - fake_message_user_property_2_encoded, - fake_message_user_property_1_encoded, - ), - "publish_payload": fake_message_body, - }, - { - "name": "send telemetry with 1 system and 1 user property", - "stage_type": "device", - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": { - "message": create_message_with_system_and_user_properties( - fake_message_body, is_multiple=False - ), - "callback": None, - }, - "topic": "devices/{}/messages/events/{}&{}".format( - fake_device_id, fake_message_id_encoded, fake_message_user_property_1_encoded - ), - "publish_payload": fake_message_body, - }, - { - "name": "send telemetry with multiple system and multiple user properties", - "stage_type": "device", - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": { - "message": create_message_with_system_and_user_properties( - fake_message_body, is_multiple=True - ), - "callback": None, - }, - # For more than 1 user property the order could be different, creating 2 different topics - "topic1": "devices/{}/messages/events/{}&{}&{}&{}".format( - fake_device_id, - fake_output_name_encoded, - fake_message_id_encoded, - fake_message_user_property_1_encoded, - fake_message_user_property_2_encoded, - ), - "topic2": "devices/{}/messages/events/{}&{}&{}&{}".format( - fake_device_id, - fake_output_name_encoded, - fake_message_id_encoded, - fake_message_user_property_2_encoded, - fake_message_user_property_1_encoded, - ), - "publish_payload": fake_message_body, - }, - { - "name": "send security message with multiple system and multiple user properties", - "stage_type": "device", - "op_class": pipeline_ops_iothub.SendD2CMessageOperation, - "op_init_kwargs": { - "message": create_security_message_with_system_and_user_properties( - fake_message_body, is_multiple=True - ), - "callback": None, - }, - # For more than 1 user property the order could be different, creating 2 different topics - "topic1": "devices/{}/messages/events/{}&{}&{}&{}&{}".format( - fake_device_id, - fake_output_name_encoded, - fake_message_id_encoded, - security_message_interface_id_encoded, - fake_message_user_property_1_encoded, - fake_message_user_property_2_encoded, - ), - "topic2": "devices/{}/messages/events/{}&{}&{}&{}&{}".format( - fake_device_id, - fake_output_name_encoded, - fake_message_id_encoded, - 
security_message_interface_id_encoded, - fake_message_user_property_2_encoded, - fake_message_user_property_1_encoded, - ), - "publish_payload": fake_message_body, - }, - { - "name": "send output", - "stage_type": "module", - "op_class": pipeline_ops_iothub.SendOutputEventOperation, - "op_init_kwargs": { - "message": Message(fake_message_body, output_name=fake_output_name), - "callback": None, - }, - "topic": "devices/{}/modules/{}/messages/events/%24.on={}".format( - fake_device_id, fake_module_id, fake_output_name - ), - "publish_payload": fake_message_body, - }, - { - "name": "send output with content type and content encoding", - "stage_type": "module", - "op_class": pipeline_ops_iothub.SendOutputEventOperation, - "op_init_kwargs": { - "message": Message( - fake_message_body, - output_name=fake_output_name, - content_type=fake_content_type, - content_encoding=fake_content_encoding, - ), - "callback": None, - }, - "topic": "devices/{}/modules/{}/messages/events/%24.on={}&{}&{}".format( - fake_device_id, - fake_module_id, - fake_output_name, - fake_content_type_encoded, - fake_content_encoding_encoded, - ), - "publish_payload": fake_message_body, - }, - { - "name": "send output with system properties", - "stage_type": "module", - "op_class": pipeline_ops_iothub.SendOutputEventOperation, - "op_init_kwargs": { - "message": Message( - fake_message_body, message_id=fake_message_id, output_name=fake_output_name - ), - "callback": None, - }, - "topic": "devices/{}/modules/{}/messages/events/%24.on={}&{}".format( - fake_device_id, fake_module_id, fake_output_name, fake_message_id_encoded - ), - "publish_payload": fake_message_body, - }, - { - "name": "send output with only 1 user property", - "stage_type": "module", - "op_class": pipeline_ops_iothub.SendOutputEventOperation, - "op_init_kwargs": { - "message": create_message_for_output_with_user_properties( - fake_message_body, is_multiple=False - ), - "callback": None, - }, - "topic": "devices/{}/modules/{}/messages/events/%24.on={}&{}".format( - fake_device_id, fake_module_id, fake_output_name, fake_message_user_property_1_encoded - ), - "publish_payload": fake_message_body, - }, - { - "name": "send output with only multiple user properties", - "stage_type": "module", - "op_class": pipeline_ops_iothub.SendOutputEventOperation, - "op_init_kwargs": { - "message": create_message_for_output_with_user_properties( - fake_message_body, is_multiple=True - ), - "callback": None, - }, - "topic1": "devices/{}/modules/{}/messages/events/%24.on={}&{}&{}".format( - fake_device_id, - fake_module_id, - fake_output_name, - fake_message_user_property_1_encoded, - fake_message_user_property_2_encoded, - ), - "topic2": "devices/{}/modules/{}/messages/events/%24.on={}&{}&{}".format( - fake_device_id, - fake_module_id, - fake_output_name, - fake_message_user_property_2_encoded, - fake_message_user_property_1_encoded, - ), - "publish_payload": fake_message_body, - }, - { - "name": "send output with 1 system and 1 user property", - "stage_type": "module", - "op_class": pipeline_ops_iothub.SendOutputEventOperation, - "op_init_kwargs": { - "message": create_message_for_output_with_system_and_user_properties( - fake_message_body, is_multiple=False - ), - "callback": None, - }, - "topic": "devices/{}/modules/{}/messages/events/%24.on={}&{}&{}".format( - fake_device_id, - fake_module_id, - fake_output_name, - fake_message_id_encoded, - fake_message_user_property_1_encoded, - ), - "publish_payload": fake_message_body, - }, - { - "name": "send method result", - 
"stage_type": "both", - "op_class": pipeline_ops_iothub.SendMethodResponseOperation, - "op_init_kwargs": {"method_response": fake_method_response, "callback": None}, - "topic": "$iothub/methods/res/__fake_method_status__/?$rid=__fake_request_id__", - "publish_payload": json.dumps(fake_method_payload), - }, -] - - -@pytest.mark.parametrize("params", publish_ops, ids=[x["name"] for x in publish_ops]) -@pytest.mark.describe("IoTHubMQTTTranslationStage - .run_op() -- called with publish operations") -class TestIoTHubMQTTConverterForPublishOps(IoTHubMQTTTranslationStageTestBase): - @pytest.fixture - def op(self, params, mocker): - op = params["op_class"](**params["op_init_kwargs"]) - op.callback = mocker.MagicMock() - return op - - @pytest.mark.it("Uses the correct topic and encodes message properties string when publishing") - def test_uses_device_topic_for_devices(self, stage, stages_configured_for_both, params, op): - if params["stage_type"] == "device" and stage.module_id: - pytest.skip() - elif params["stage_type"] == "module" and not stage.module_id: - pytest.skip() - stage.run_op(op) - new_op = stage.next._run_op.call_args[0][0] - if "multiple user properties" in params["name"]: - assert new_op.topic == params["topic1"] or new_op.topic == params["topic2"] - else: - assert new_op.topic == params["topic"] - - @pytest.mark.it("Sends the body in the payload of the MQTT publish operation") - def test_sends_correct_body(self, stage, stages_configured_for_both, params, op): + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) + assert new_op.topic == mock_mqtt_topic.encode_message_properties_in_topic.return_value + assert new_op.payload == op.message.data + + @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") + def test_complete_resulting_op(self, stage, op, op_error): stage.run_op(op) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.payload == params["publish_payload"] - - -feature_name_to_subscribe_topic = [ - { - "stage_type": "device", - "feature_name": constant.C2D_MSG, - "topic": "devices/{}/messages/devicebound/#".format(fake_device_id), - }, - { - "stage_type": "module", - "feature_name": constant.INPUT_MSG, - "topic": "devices/{}/modules/{}/inputs/#".format(fake_device_id, fake_module_id), - }, - {"stage_type": "both", "feature_name": constant.METHODS, "topic": "$iothub/methods/POST/#"}, -] - -sub_unsub_operations = [ - { - "op_class": pipeline_ops_base.EnableFeatureOperation, - "new_op": pipeline_ops_mqtt.MQTTSubscribeOperation, - }, - { - "op_class": pipeline_ops_base.DisableFeatureOperation, - "new_op": pipeline_ops_mqtt.MQTTUnsubscribeOperation, - }, -] + assert not op.completed + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + + new_op.complete(error=op_error) + + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error @pytest.mark.describe( - "IoTHubMQTTTranslationStage - .run_op() -- called with EnableFeature or DisableFeature" + "IoTHubMQTTTranslationStage - .run_op() -- Called with SendMethodResponseOperation" ) -class TestIoTHubMQTTConverterWithEnableFeature(IoTHubMQTTTranslationStageTestBase): - @pytest.mark.parametrize( - "topic_parameters", - feature_name_to_subscribe_topic, - ids=[ - "{} {}".format(x["stage_type"], x["feature_name"]) - for x in feature_name_to_subscribe_topic - ], - ) - @pytest.mark.parametrize( - 
"op_parameters", - sub_unsub_operations, - ids=[x["op_class"].__name__ for x in sub_unsub_operations], - ) - @pytest.mark.it("Converts the feature_name to the correct topic") - def test_converts_feature_name_to_topic( - self, mocker, stage, stages_configured_for_both, topic_parameters, op_parameters - ): - if topic_parameters["stage_type"] == "device" and stage.module_id: - pytest.skip() - elif topic_parameters["stage_type"] == "module" and not stage.module_id: - pytest.skip() - stage.next._run_op = mocker.Mock() - op = op_parameters["op_class"]( - feature_name=topic_parameters["feature_name"], callback=mocker.MagicMock() +class TestIoTHubMQTTTranslationStageWithSendMethodResponseOperation( + StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig +): + @pytest.fixture + def op(self, mocker): + method_response = MethodResponse( + request_id="fake_request_id", status=200, payload={"some": "json"} + ) + return pipeline_ops_iothub.SendMethodResponseOperation( + method_response=method_response, callback=mocker.MagicMock() ) + + @pytest.mark.it("Derives the IoTHub telemetry topic using the op's request id and status") + def test_telemtry_topic(self, mocker, stage, op, mock_mqtt_topic): stage.run_op(op) - new_op = stage.next._run_op.call_args[0][0] - assert isinstance(new_op, op_parameters["new_op"]) - assert new_op.topic == topic_parameters["topic"] - @pytest.mark.it("Fails on an invalid feature_name") + assert mock_mqtt_topic.get_method_topic_for_publish.call_count == 1 + assert mock_mqtt_topic.get_method_topic_for_publish.call_args == mocker.call( + op.method_response.request_id, op.method_response.status + ) + + @pytest.mark.it( + "Sends a new MQTTPublishOperation down the pipeline with the original op's payload in JSON string format, and the derived topic string" + ) @pytest.mark.parametrize( - "op_parameters", - sub_unsub_operations, - ids=[x["op_class"].__name__ for x in sub_unsub_operations], + "payload, expected_string", + [ + pytest.param(None, "null", id="No payload"), + pytest.param({"some": "json"}, '{"some": "json"}', id="Dictionary payload"), + pytest.param("payload", '"payload"', id="String payload"), + ], ) - def test_fails_on_invalid_feature_name( - self, mocker, stage, stages_configured_for_both, op_parameters + def test_sends_mqtt_publish_op_down( + self, mocker, stage, op, mock_mqtt_topic, payload, expected_string ): - op = op_parameters["op_class"]( - feature_name=invalid_feature_name, callback=mocker.MagicMock() - ) - mocker.spy(op, "complete") + op.method_response.payload = payload stage.run_op(op) - assert op.complete.call_count == 1 - assert isinstance(op.complete.call_args[1]["error"], KeyError) - # assert_callback_failed(op=op, error=KeyError) - -@pytest.fixture -def add_pipeline_root(stage, mocker): - root = pipeline_stages_base.PipelineRootStage(mocker.MagicMock()) - mocker.spy(root, "handle_pipeline_event") - stage.previous = root - stage.pipeline_root = root + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) + assert new_op.topic == mock_mqtt_topic.get_method_topic_for_publish.return_value + assert new_op.payload == expected_string + @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") + def test_complete_resulting_op(self, stage, op, op_error): + stage.run_op(op) + assert not op.completed -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- called with unmatched topic" -) -class 
TestIoTHubMQTTConverterHandlePipelineEvent(IoTHubMQTTTranslationStageTestBase): - @pytest.mark.it("Passes up any mqtt messages with topics that aren't matched by this stage") - def test_passes_up_mqtt_message_with_unknown_topic( - self, stage, stages_configured_for_both, add_pipeline_root, mocker - ): - event = pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=unmatched_mqtt_topic, payload=fake_mqtt_payload - ) - stage.handle_pipeline_event(event) - assert stage.previous.handle_pipeline_event.call_count == 1 - assert stage.previous.handle_pipeline_event.call_args == mocker.call(event) + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + new_op.complete(error=op_error) -@pytest.fixture -def c2d_event(): - return pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=fake_c2d_topic, payload=fake_mqtt_payload - ) + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error @pytest.mark.describe( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- called with C2D topic" + "IoTHubMQTTTranslationStage - .run_op() -- Called with EnableFeatureOperation" ) -class TestIoTHubMQTTConverterHandlePipelineEventC2D(IoTHubMQTTTranslationStageTestBase): +class TestIoTHubMQTTTranslationStageRunOpWithEnableFeatureOperation( + StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig +): + @pytest.fixture + def op(self, mocker, iothub_pipeline_feature): + return pipeline_ops_base.EnableFeatureOperation( + feature_name=iothub_pipeline_feature, callback=mocker.MagicMock() + ) + @pytest.mark.it( - "Converts mqtt message with topic devices/device_id/message/devicebound/ to c2d event" + "Sends a new MQTTSubscribeOperation down the pipeline, containing the subscription topic string corresponding to the feature being enabled" ) - def test_converts_c2d_topic_to_c2d_events( - self, mocker, stage, stage_configured_for_device, add_pipeline_root, c2d_event + def test_mqtt_subscribe_sent_down( + self, op, stage, expected_mqtt_topic_fn, expected_mqtt_topic_fn_call ): - stage.handle_pipeline_event(c2d_event) - assert stage.previous.handle_pipeline_event.call_count == 1 - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert isinstance(new_event, pipeline_events_iothub.C2DMessageEvent) + stage.run_op(op) - @pytest.mark.it("Convers the mqtt payload of a c2d message into a Message object") - def test_creates_message_object_for_c2d_event( - self, mocker, stage, stage_configured_for_device, add_pipeline_root, c2d_event - ): - stage.handle_pipeline_event(c2d_event) - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert isinstance(new_event.message, Message) + # Topic was derived as expected + assert expected_mqtt_topic_fn.call_count == 1 + assert expected_mqtt_topic_fn.call_args == expected_mqtt_topic_fn_call - @pytest.mark.it("Extracts message properties from the mqtt topic for c2d messages") - def test_extracts_c2d_message_properties_from_topic_name( - self, mocker, stage, stage_configured_for_device, add_pipeline_root - ): - event = pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=fake_c2d_topic_with_content_type, payload=fake_mqtt_payload + # New op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTSubscribeOperation) + + # New op has the expected topic + assert new_op.topic == expected_mqtt_topic_fn.return_value + + @pytest.mark.it("Completes the original op upon completion of 
the new MQTTSubscribeOperation") + def test_complete_resulting_op(self, stage, op, op_error): + stage.run_op(op) + assert not op.completed + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + + new_op.complete(error=op_error) + + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error + + +@pytest.mark.describe( + "IoTHubMQTTTranslationStage - .run_op() -- Called with DisableFeatureOperation" +) +class TestIoTHubMQTTTranslationStageRunOpWithDisableFeatureOperation( + StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig +): + @pytest.fixture + def op(self, mocker, iothub_pipeline_feature): + return pipeline_ops_base.DisableFeatureOperation( + feature_name=iothub_pipeline_feature, callback=mocker.MagicMock() ) - stage.handle_pipeline_event(event) - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert new_event.message.content_type == fake_content_type - @pytest.mark.it("Passes up c2d messages destined for another device") - def test_if_topic_is_c2d_for_another_device( - self, mocker, stage, stage_configured_for_device, add_pipeline_root + @pytest.mark.it( + "Sends a new MQTTUnsubscribeOperation down the pipeline, containing the subscription topic string corresponding to the feature being disabled" + ) + def test_mqtt_unsubscribe_sent_down( + self, op, stage, expected_mqtt_topic_fn, expected_mqtt_topic_fn_call ): - event = pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=fake_c2d_topic_for_another_device, payload=fake_mqtt_payload - ) - stage.handle_pipeline_event(event) - assert stage.previous.handle_pipeline_event.call_count == 1 - assert stage.previous.handle_pipeline_event.call_args == mocker.call(event) + stage.run_op(op) + # Topic was derived as expected + assert expected_mqtt_topic_fn.call_count == 1 + assert expected_mqtt_topic_fn.call_args == expected_mqtt_topic_fn_call -@pytest.mark.describe("IotHubMQTTConverter - .run_op() -- called with RequestOperation") -class TestIotHubMQTTConverterWithSendIotRequest(IoTHubMQTTTranslationStageTestBase): - @pytest.fixture - def fake_request_type(self): - return "twin" + # New op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTUnsubscribeOperation) - @pytest.fixture - def fake_method(self): - return "__fake_method__" + # New op has the expected topic + assert new_op.topic == expected_mqtt_topic_fn.return_value - @pytest.fixture - def fake_resource_location(self): - return "__fake_resource_location__" + @pytest.mark.it("Completes the original op upon completion of the new MQTTUnsubscribeOperation") + def test_complete_resulting_op(self, stage, op, op_error): + stage.run_op(op) + assert not op.completed - @pytest.fixture - def fake_request_body(self): - return "__fake_request_body__" + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] - @pytest.fixture - def fake_request_body_as_string(self, fake_request_body): - return json.dumps(fake_request_body) + new_op.complete(error=op_error) - @pytest.fixture - def fake_request_id(self): - return "__fake_request_id__" + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error + +@pytest.mark.describe("IoTHubMQTTTranslationStage - .run_op() -- Called with RequestOperation") +class TestIoTHubMQTTTranslationStageWithRequestOperation( + StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig 
+): @pytest.fixture - def op( - self, - fake_request_type, - fake_method, - fake_resource_location, - fake_request_body, - fake_request_id, - mocker, - ): - op = pipeline_ops_base.RequestOperation( - request_type=fake_request_type, - method=fake_method, - resource_location=fake_resource_location, - request_body=fake_request_body, - request_id=fake_request_id, + def op(self, mocker): + # Only request operation supported at present by this stage is TWIN. If this changes, + # logic in this whole test class must become more robust + return pipeline_ops_base.RequestOperation( + request_type=constant.TWIN, + method="GET", + resource_location="/", + request_body=" ", + request_id="fake_request_id", callback=mocker.MagicMock(), ) - mocker.spy(op, "complete") - mocker.spy(op, "spawn_worker_op") - return op - - @pytest.mark.it("calls the op callback with an OperationError if request_type is not 'twin'") - def test_sends_bad_request_type(self, stage, op): - op.request_type = "not_twin" - stage.run_op(op) - assert op.complete.call_count == 1 - assert isinstance(op.complete.call_args[1]["error"], OperationError) @pytest.mark.it( - "Runs an MQTTPublishOperation as a worker op on the next stage with the topic formated as '$iothub/twin/{method}{resource_location}?$rid={request_id}' and the payload as the request_body" + "Derives the IoTHub Twin Request topic using the op's details, if the op is a Twin Request" ) - def test_sends_new_operation( - self, stage, op, fake_method, fake_resource_location, fake_request_id, fake_request_body - ): + def test_twin_request_topic(self, mocker, stage, op, mock_mqtt_topic): stage.run_op(op) - assert op.spawn_worker_op.call_count == 1 - assert ( - op.spawn_worker_op.call_args[1]["worker_op_type"] - is pipeline_ops_mqtt.MQTTPublishOperation - ) - assert stage.next.run_op.call_count == 1 - worker_op = stage.next.run_op.call_args[0][0] - assert isinstance(worker_op, pipeline_ops_mqtt.MQTTPublishOperation) - assert ( - worker_op.topic - == "$iothub/twin/{method}{resource_location}?$rid={request_id}".format( - method=fake_method, - resource_location=fake_resource_location, - request_id=fake_request_id, - ) - ) - assert worker_op.payload == fake_request_body + assert mock_mqtt_topic.get_twin_topic_for_publish.call_count == 1 + assert mock_mqtt_topic.get_twin_topic_for_publish.call_args == mocker.call( + method=op.method, resource_location=op.resource_location, request_id=op.request_id + ) -@pytest.fixture -def input_message_event(): - return pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=fake_input_message_topic, payload=fake_mqtt_payload + @pytest.mark.it( + "Completes the operation with an OperationError failure if the op is any type of request other than a Twin Request" ) + def test_invalid_op(self, mocker, stage, op): + # Okay, so technically this doesn't prove it does this if it's ANY other type of request, but that's pretty much + # impossible to disprove in a black-box test, because there are infinite possibilities in theory + op.request_type = "Some_other_type" + stage.run_op(op) + assert op.completed + assert isinstance(op.error, OperationError) - -@pytest.mark.describe( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- called with input message topic" -) -class TestIoTHubMQTTConverterHandlePipelineEventInputMessages(IoTHubMQTTTranslationStageTestBase): @pytest.mark.it( - "Converts mqtt message with topic devices/device_id/modules/module_id/inputs/input_name/ to input event" + "Sends a new MQTTPublishOperation down the pipeline with the original op's
request body and the derived topic string" ) - def test_converts_input_topic_to_input_event( - self, mocker, stage, stage_configured_for_module, add_pipeline_root, input_message_event - ): - stage.handle_pipeline_event(input_message_event) - assert stage.previous.handle_pipeline_event.call_count == 1 - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert isinstance(new_event, pipeline_events_iothub.InputMessageEvent) + def test_sends_mqtt_publish_op_down(self, mocker, stage, op, mock_mqtt_topic): + stage.run_op(op) - @pytest.mark.it("Converts the mqtt payload of an input message into a Message object") - def test_creates_message_object_for_input_event( - self, mocker, stage, stage_configured_for_module, add_pipeline_root, input_message_event - ): - stage.handle_pipeline_event(input_message_event) - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert isinstance(new_event.message, Message) + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) + assert new_op.topic == mock_mqtt_topic.get_twin_topic_for_publish.return_value + assert new_op.payload == op.request_body - @pytest.mark.it("Extracts the input name of an input message from the mqtt topic") - def test_extracts_input_name_from_topic( - self, mocker, stage, stage_configured_for_module, add_pipeline_root, input_message_event - ): - stage.handle_pipeline_event(input_message_event) - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert new_event.input_name == fake_input_name + @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") + def test_complete_resulting_op(self, stage, op, op_error): + stage.run_op(op) + assert not op.completed - @pytest.mark.it("Extracts message properties from the mqtt topic for input messages") - def test_extracts_input_message_properties_from_topic_name( - self, mocker, stage, stage_configured_for_module, add_pipeline_root - ): - event = pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=fake_input_message_topic_with_content_type, payload=fake_mqtt_payload - ) - stage.handle_pipeline_event(event) - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert new_event.message.content_type == fake_content_type + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] - @pytest.mark.parametrize( - "topic", - [fake_input_message_topic_for_another_device, fake_input_message_topic_for_another_module], - ids=["different device_id", "same device_id"], - ) - @pytest.mark.it("Passes up input messages destined for another module") - def test_if_topic_is_input_message_for_another_module( - self, mocker, stage, stage_configured_for_module, add_pipeline_root, topic - ): - event = pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=topic, payload=fake_mqtt_payload - ) - stage.handle_pipeline_event(event) - assert stage.previous.handle_pipeline_event.call_count == 1 - assert stage.previous.handle_pipeline_event.call_args == mocker.call(event) + new_op.complete(error=op_error) + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error -@pytest.fixture -def method_request_event(): - return pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=fake_method_request_topic, payload=fake_method_request_payload - ) + +@pytest.mark.describe( + "IoTHubMQTTTranslationStage - .run_op() -- Called with other arbitrary operation" +) 
+class TestIoTHubMQTTTranslationStageRunOpWithArbitraryOperation( + StageRunOpTestBase, IoTHubMQTTTranslationStageTestConfig +): + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + + @pytest.mark.it("Sends the operation down the pipeline") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) @pytest.mark.describe( - "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- called with method request topic" + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (C2D topic string)" ) -class TestIoTHubMQTTConverterHandlePipelineEventMethodRequets(IoTHubMQTTTranslationStageTestBase): +class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventC2DTopic( + StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig +): + @pytest.fixture + def event(self, pipeline_config): + # topic device id MATCHES THE PIPELINE CONFIG + topic = "devices/{device_id}/messages/devicebound/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2F{device_id}%2Fmessages%2Fdevicebound".format( + device_id=pipeline_config.device_id + ) + return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload="some payload") + @pytest.mark.it( - "Converts mqtt messages with topic $iothub/methods/POST/{method name}/?$rid={request id} to method request events" + "Creates a Message with the event's payload, and applies any message properties included in the topic" ) - def test_converts_method_request_topic_to_method_request_event( - self, mocker, stage, stages_configured_for_both, add_pipeline_root, method_request_event - ): - stage.handle_pipeline_event(method_request_event) - assert stage.previous.handle_pipeline_event.call_count == 1 - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert isinstance(new_event, pipeline_events_iothub.MethodRequestEvent) + def test_message(self, event, stage, mock_mqtt_topic): + stage.handle_pipeline_event(event) - @pytest.mark.it("Makes a MethodRequest object to hold the method request details") - def test_passes_method_request_object_in_method_request_event( - self, mocker, stage, stages_configured_for_both, add_pipeline_root, method_request_event - ): - stage.handle_pipeline_event(method_request_event) - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert isinstance(new_event.method_request, MethodRequest) + # Message properties were extracted from the topic + # NOTE that because this is mocked, we don't need to test various topics with various properties + assert mock_mqtt_topic.extract_message_properties_from_topic.call_count == 1 + assert mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][0] == event.topic + message = mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][1] + assert isinstance(message, Message) + # The message contains the event's payload + assert message.data == event.payload - @pytest.mark.it("Extracts the method name from the mqtt topic") - def test_extracts_method_name_from_method_request_topic( - self, mocker, stage, stages_configured_for_both, add_pipeline_root, method_request_event - ): - stage.handle_pipeline_event(method_request_event) - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert new_event.method_request.name == fake_method_name + @pytest.mark.it( + "Sends a new C2DMessageEvent up the pipeline, containing the newly created Message" +
) + def test_c2d_message_event(self, event, stage, mock_mqtt_topic): + stage.handle_pipeline_event(event) - @pytest.mark.it("Extracts the request id from the mqtt topic") - def test_extracts_request_id_from_method_request_topic( - self, mocker, stage, stages_configured_for_both, add_pipeline_root, method_request_event - ): - stage.handle_pipeline_event(method_request_event) - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert new_event.method_request.request_id == fake_request_id + # C2DMessageEvent was sent up the pipeline + assert stage.send_event_up.call_count == 1 + new_event = stage.send_event_up.call_args[0][0] + assert isinstance(new_event, pipeline_events_iothub.C2DMessageEvent) + # The C2DMessageEvent contains the same Message that was created from the topic details + assert mock_mqtt_topic.extract_message_properties_from_topic.call_count == 1 + message = mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][1] + assert new_event.message is message @pytest.mark.it( - "Puts the payload of the mqtt message as the payload of the method requets object" + "Sends the original event up the pipeline instead, if the device id in the topic string does not match the client details" ) - def test_puts_mqtt_payload_in_method_request_payload( - self, mocker, stage, stages_configured_for_both, add_pipeline_root, method_request_event - ): - stage.handle_pipeline_event(method_request_event) - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert new_event.method_request.payload == json.loads( - fake_method_request_payload.decode("utf-8") - ) + def test_nonmatching_device_id(self, mocker, event, stage): + stage.pipeline_root.pipeline_configuration.device_id = "different_device_id" + stage.handle_pipeline_event(event) + + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args == mocker.call(event) @pytest.mark.describe( - "IotHubMQTTConverter - .handle_pipeline_event() -- called with twin response topic" + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Input Message topic string)" ) -class TestIotHubMQTTConverterHandlePipelineEventTwinResponse(IoTHubMQTTTranslationStageTestBase): - @pytest.fixture - def fake_request_id(self): - return "__fake_request_id__" - +class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventInputTopic( + StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig +): @pytest.fixture - def fake_status_code(self): - return 200 + def pipeline_config(self, mocker): + cfg = config.IoTHubPipelineConfig( + hostname="fake_hostname", + device_id="my_device", + module_id="my_module", + sastoken=mocker.MagicMock(), + ) + return cfg @pytest.fixture - def bad_status_code(self): - return "__bad_status_code__" + def input_name(self): + return "some_input" @pytest.fixture - def fake_topic_name(self, fake_request_id, fake_status_code): - return "$iothub/twin/res/{status_code}/?$rid={request_id}".format( - status_code=fake_status_code, request_id=fake_request_id + def event(self, pipeline_config, input_name): + # topic device id MATCHES THE PIPELINE CONFIG + topic = "devices/{device_id}/modules/{module_id}/inputs/{input_name}/%24.mid=6b822696-f75a-46f5-8b02-0680db65abf5&%24.to=%2Fdevices%2F{device_id}%2Fmodules%2F{module_id}%2Finputs%2F{input_name}".format( + device_id=pipeline_config.device_id, + module_id=pipeline_config.module_id, + input_name=input_name, ) + return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, 
payload="some payload") - @pytest.fixture - def fake_topic_name_with_missing_request_id(self, fake_status_code): - return "$iothub/twin/res/{status_code}".format(status_code=fake_status_code) + @pytest.mark.it( + "Creates a Message with the event's payload, and applies any message properties included in the topic" + ) + def test_message(self, event, stage, mock_mqtt_topic): + stage.handle_pipeline_event(event) - @pytest.fixture - def fake_topic_name_with_missing_status_code(self, fake_request_id): - return "$iothub/twin/res/?$rid={request_id}".format(request_id=fake_request_id) + # Message properties were extracted from the topic + # NOTE that because this is mocked, we don't need to test various topics with various properties + assert mock_mqtt_topic.extract_message_properties_from_topic.call_count == 1 + assert mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][0] == event.topic + message = mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][1] + assert isinstance(message, Message) + # The message contains the event's payload + assert message.data == event.payload + + @pytest.mark.it( + "Sends a new InputMessageEvent up the pipeline, containing the newly created Message and the input name extracted from the topic" + ) + def test_input_message_event(self, event, stage, mock_mqtt_topic, input_name): + stage.handle_pipeline_event(event) + + # InputMessageEvent was sent up the pipeline + assert stage.send_event_up.call_count == 1 + new_event = stage.send_event_up.call_args[0][0] + assert isinstance(new_event, pipeline_events_iothub.InputMessageEvent) + # The InputMessageEvent contains the same Message that was created from the topic details + assert mock_mqtt_topic.extract_message_properties_from_topic.call_count == 1 + message = mock_mqtt_topic.extract_message_properties_from_topic.call_args[0][1] + assert new_event.message is message + # The InputMessageEvent contains the same input name from the topic + assert new_event.input_name == input_name + @pytest.mark.it( + "Sends the original event up the pipeline instead, if the the topic string does not match the client details" + ) + @pytest.mark.parametrize( + "alt_device_id, alt_module_id", + [ + pytest.param("different_device_id", None, id="Non-matching device id"), + pytest.param(None, "different_module_id", id="Non-matching module id"), + pytest.param( + "different_device_id", + "different_module_id", + id="Non-matching device id AND module id", + ), + ], + ) + def test_nonmatching_ids(self, mocker, event, stage, alt_device_id, alt_module_id): + if alt_device_id: + stage.pipeline_root.pipeline_configuration.device_id = alt_device_id + if alt_module_id: + stage.pipeline_root.pipeline_configuration.module_id = alt_module_id + stage.handle_pipeline_event(event) + + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args == mocker.call(event) + + +@pytest.mark.describe( + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Method topic string)" +) +class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventMethodTopic( + StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig +): @pytest.fixture - def fake_topic_name_with_bad_status_code(self, fake_request_id, bad_status_code): - return "$iothub/twin/res/{status_code}/?$rid={request_id}".format( - request_id=fake_request_id, status_code=bad_status_code - ) + def method_name(self): + return "some_method" @pytest.fixture - def fake_payload(self): - return 
"__fake_payload__" + def rid(self): + return "1" @pytest.fixture - def fake_event(self, fake_topic_name, fake_payload): + def event(self, method_name, rid): + topic = "$iothub/methods/POST/{method_name}/?$rid={rid}".format( + method_name=method_name, rid=rid + ) return pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=fake_topic_name, payload=fake_payload + topic=topic, payload=b'{"some": "json"}' ) - @pytest.fixture - def fixup_stage_for_test(self, stage, add_pipeline_root): - print("Adding module") - stage.module_id = fake_module_id - stage.device_id = fake_device_id - @pytest.mark.it( - "Calls .handle_pipeline_event() on the previous stage with an ResponseEvent, with request_id and status_code as attributes extracted from the topic and the response_body attirbute set to the payload" + "Sends a MethodRequestEvent up the pipeline with a MethodRequest containing values extracted from the event's topic" ) - def test_extracts_request_id_status_code_and_payload( - self, - stage, - fixup_stage_for_test, - fake_request_id, - fake_status_code, - fake_payload, - fake_event, - ): - stage.handle_pipeline_event(event=fake_event) - assert stage.previous.handle_pipeline_event.call_count == 1 - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert isinstance(new_event, pipeline_events_base.ResponseEvent) - assert new_event.status_code == fake_status_code - assert new_event.request_id == fake_request_id - assert new_event.response_body == fake_payload + def test_method_request(self, event, stage, method_name, rid): + stage.handle_pipeline_event(event) - @pytest.mark.it( - "Calls the unhandled exception handler with a PipelineError if there is no previous stage" - ) - def test_no_previous_stage( - self, stage, fixup_stage_for_test, fake_event, unhandled_error_handler - ): - stage.previous = None - stage.handle_pipeline_event(fake_event) - assert unhandled_error_handler.call_count == 1 - assert isinstance(unhandled_error_handler.call_args[0][0], PipelineError) + assert stage.send_event_up.call_count == 1 + new_event = stage.send_event_up.call_args[0][0] + assert isinstance(new_event, pipeline_events_iothub.MethodRequestEvent) + assert isinstance(new_event.method_request, MethodRequest) + assert new_event.method_request.name == method_name + assert new_event.method_request.request_id == rid + # This is expanded on in in the next test + assert new_event.method_request.payload == json.loads(event.payload.decode("utf-8")) @pytest.mark.it( - "Calls the unhandled exception handler if the requet_id is missing from the topic name" + "Derives the MethodRequestEvent's payload by converting the original event's payload from bytes into a JSON object" ) - def test_invalid_topic_with_missing_request_id( - self, - stage, - fixup_stage_for_test, - fake_event, - fake_topic_name_with_missing_request_id, - unhandled_error_handler, - ): - fake_event.topic = fake_topic_name_with_missing_request_id - stage.handle_pipeline_event(event=fake_event) - assert unhandled_error_handler.call_count == 1 - assert isinstance(unhandled_error_handler.call_args[0][0], IndexError) - - @pytest.mark.it( - "Calls the unhandled exception handler if the status code is missing from the topic name" + @pytest.mark.parametrize( + "original_payload, derived_payload", + [ + pytest.param(b'{"some": "payload"}', {"some": "payload"}, id="Dictionary JSON"), + pytest.param(b'"payload"', "payload", id="String JSON"), + pytest.param(b"1234", 1234, id="Int JSON"), + pytest.param(b"null", None, id="None JSON"), + ], ) - def 
test_invlid_topic_with_missing_status_code( - self, - stage, - fixup_stage_for_test, - fake_event, - fake_topic_name_with_missing_status_code, - unhandled_error_handler, - ): - fake_event.topic = fake_topic_name_with_missing_status_code - stage.handle_pipeline_event(event=fake_event) - assert unhandled_error_handler.call_count == 1 - assert isinstance(unhandled_error_handler.call_args[0][0], ValueError) + def test_json_payload(self, event, stage, original_payload, derived_payload): + event.payload = original_payload + stage.handle_pipeline_event(event) - @pytest.mark.it( - "Calls the unhandled exception handler if the status code in the topic name is not numeric" - ) - def test_invlid_topic_with_bad_status_code( - self, - stage, - fixup_stage_for_test, - fake_event, - fake_topic_name_with_bad_status_code, - unhandled_error_handler, - ): - fake_event.topic = fake_topic_name_with_bad_status_code - stage.handle_pipeline_event(event=fake_event) - assert unhandled_error_handler.call_count == 1 - assert isinstance(unhandled_error_handler.call_args[0][0], ValueError) + assert stage.send_event_up.call_count == 1 + new_event = stage.send_event_up.call_args[0][0] + assert isinstance(new_event, pipeline_events_iothub.MethodRequestEvent) + assert isinstance(new_event.method_request, MethodRequest) + + assert new_event.method_request.payload == derived_payload @pytest.mark.describe( - "IotHubMQTTConverter - .handle_pipeline_event() -- called with twin patch topic" + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Twin Response topic string)" ) -class TestIotHubMQTTConverterHandlePipelineEventTwinPatch(IoTHubMQTTTranslationStageTestBase): +class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventTwinResponseTopic( + StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig +): @pytest.fixture - def fake_topic_name(self): - return "$iothub/twin/PATCH/properties/desired" + def status(self): + return 200 @pytest.fixture - def fake_patch(self): - return {"__fake_patch__": "yes"} + def rid(self): + return "d9d7ce4d-3be9-498b-abde-913b81b880e5" @pytest.fixture - def fake_patch_as_bytes(self, fake_patch): - return json.dumps(fake_patch).encode("utf-8") + def event(self, status, rid): + topic = "$iothub/twin/res/{status}/?$rid={rid}".format(status=status, rid=rid) + return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=b"some_payload") - @pytest.fixture - def fake_patch_not_bytes(self): - return "__fake_patch_that_is_not_bytes__" + @pytest.mark.it( + "Sends a ResponseEvent up the pipeline containing the original event's payload, and values extracted from the topic string" + ) + def test_response_event(self, event, stage, status, rid): + stage.handle_pipeline_event(event) - @pytest.fixture - def fake_patch_not_json(self): - return "__fake_patch_that_is_not_json__".encode("utf-8") + assert stage.send_event_up.call_count == 1 + new_event = stage.send_event_up.call_args[0][0] + assert isinstance(new_event, pipeline_events_base.ResponseEvent) + assert new_event.status_code == status + assert new_event.request_id == rid + assert new_event.response_body == event.payload + +@pytest.mark.describe( + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Twin Desired Properties Patch topic string)" +) +class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventTwinDesiredPropertiesPatchTopic( + StageHandlePipelineEventTestBase, 
IoTHubMQTTTranslationStageTestConfig +): @pytest.fixture - def fake_event(self, fake_topic_name, fake_patch_as_bytes): + def event(self): + topic = "$iothub/twin/PATCH/properties/desired/?$version=1" + # payload will be overwritten in relevant tests return pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=fake_topic_name, payload=fake_patch_as_bytes + topic=topic, payload=b'{"some": "payload"}' ) - @pytest.fixture - def fixup_stage_for_test(self, stage, add_pipeline_root): - print("Adding module") - stage.module_id = fake_module_id - stage.device_id = fake_device_id - @pytest.mark.it( - "Calls .handle_pipeline_event() on the previous stage with an TwinDesiredPropertiesPatchEvent, with the patch set to the payload after decoding and deserializing it" + "Sends a TwinDesiredPropertiesPatchEvent up the pipeline, containing the original event's payload formatted as a JSON-object" ) - def test_calls_previous_stage(self, stage, fixup_stage_for_test, fake_event, fake_patch): - stage.handle_pipeline_event(fake_event) - assert stage.previous.handle_pipeline_event.call_count == 1 - new_event = stage.previous.handle_pipeline_event.call_args[0][0] + @pytest.mark.parametrize( + "original_payload, derived_payload", + [ + pytest.param(b'{"some": "payload"}', {"some": "payload"}, id="Dictionary JSON"), + pytest.param(b'"payload"', "payload", id="String JSON"), + pytest.param(b"1234", 1234, id="Int JSON"), + pytest.param(b"null", None, id="None JSON"), + ], + ) + def test_twin_patch_event(self, event, stage, original_payload, derived_payload): + event.payload = original_payload + stage.handle_pipeline_event(event) + + assert stage.send_event_up.call_count == 1 + new_event = stage.send_event_up.call_args[0][0] assert isinstance(new_event, pipeline_events_iothub.TwinDesiredPropertiesPatchEvent) - assert new_event.patch == fake_patch + assert new_event.patch == derived_payload - @pytest.mark.it( - "Calls the unhandled exception handler with a PipelineError if there is no previous stage" - ) - def test_no_previous_stage( - self, stage, fixup_stage_for_test, fake_event, unhandled_error_handler - ): - stage.previous = None - stage.handle_pipeline_event(fake_event) - assert unhandled_error_handler.call_count == 1 - assert isinstance(unhandled_error_handler.call_args[0][0], PipelineError) - - @pytest.mark.it("Calls the unhandled exception handler if the payload is not a Bytes object") - def test_payload_not_bytes( - self, stage, fixup_stage_for_test, fake_event, fake_patch_not_bytes, unhandled_error_handler - ): - fake_event.payload = fake_patch_not_bytes - stage.handle_pipeline_event(fake_event) - assert unhandled_error_handler.call_count == 1 - if not ( - isinstance(unhandled_error_handler.call_args[0][0], AttributeError) - or isinstance(unhandled_error_handler.call_args[0][0], ValueError) - ): - assert False - @pytest.mark.it( - "Calls the unhandled exception handler if the payload cannot be deserialized as a JSON object" - ) - def test_payload_not_json( - self, stage, fixup_stage_for_test, fake_event, fake_patch_not_json, unhandled_error_handler - ): - fake_event.payload = fake_patch_not_json - stage.handle_pipeline_event(fake_event) - assert unhandled_error_handler.call_count == 1 - assert isinstance(unhandled_error_handler.call_args[0][0], ValueError) +@pytest.mark.describe( + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Unrecognized topic string)" +) +class TestIoTHubMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventUnknownTopicString( 
+ StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig +): + @pytest.fixture + def event(self): + topic = "not a real topic" + return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=b"some payload") + + @pytest.mark.it("Sends the event up the pipeline") + def test_sends_up(self, event, stage): + stage.handle_pipeline_event(event) + + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args[0][0] == event + + +@pytest.mark.describe( + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with other arbitrary event" +) +class TestIoTHubMQTTTranslationStageHandlePipelineEventWithArbitraryEvent( + StageHandlePipelineEventTestBase, IoTHubMQTTTranslationStageTestConfig +): + @pytest.fixture + def event(self, arbitrary_event): + return arbitrary_event + + @pytest.mark.it("Sends the event up the pipeline") + def test_sends_up(self, event, stage): + stage.handle_pipeline_event(event) + + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args[0][0] == event diff --git a/azure-iot-device/tests/iothub/shared_client_tests.py b/azure-iot-device/tests/iothub/shared_client_tests.py new file mode 100644 index 000000000..66e92401a --- /dev/null +++ b/azure-iot-device/tests/iothub/shared_client_tests.py @@ -0,0 +1,1076 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This module contains tests that are shared between sync/async clients +i.e. tests for things defined in abstract clients""" + +import pytest +import logging +import os +import io +import six +import socks + +from azure.iot.device.common import auth +from azure.iot.device.common.auth import sastoken as st +from azure.iot.device.common.auth import connection_string as cs +from azure.iot.device.iothub.pipeline import IoTHubPipelineConfig +from azure.iot.device.iothub import edge_hsm +from azure.iot.device import ProxyOptions + +logging.basicConfig(level=logging.DEBUG) + + +################################ +# SHARED DEVICE + MODULE TESTS # +################################ + + +class SharedIoTHubClientInstantiationTests(object): + @pytest.mark.it( + "Stores the MQTTPipeline from the 'mqtt_pipeline' parameter in the '_mqtt_pipeline' attribute" + ) + def test_mqtt_pipeline_attribute(self, client_class, mqtt_pipeline, http_pipeline): + client = client_class(mqtt_pipeline, http_pipeline) + + assert client._mqtt_pipeline is mqtt_pipeline + + @pytest.mark.it( + "Stores the HTTPPipeline from the 'http_pipeline' parameter in the '_http_pipeline' attribute" + ) + def test_sets_http_pipeline_attribute(self, client_class, mqtt_pipeline, http_pipeline): + client = client_class(mqtt_pipeline, http_pipeline) + + assert client._http_pipeline is http_pipeline + + @pytest.mark.it("Sets on_connected handler in the MQTTPipeline") + def test_sets_on_connected_handler_in_pipeline( + self, client_class, mqtt_pipeline, http_pipeline + ): + client = client_class(mqtt_pipeline, http_pipeline) + + assert client._mqtt_pipeline.on_connected is not None + assert client._mqtt_pipeline.on_connected == client._on_connected + + @pytest.mark.it("Sets on_disconnected handler in the MQTTPipeline") + def test_sets_on_disconnected_handler_in_pipeline( + self, client_class, mqtt_pipeline, http_pipeline + ): + client = 
client_class(mqtt_pipeline, http_pipeline) + + assert client._mqtt_pipeline.on_disconnected is not None + assert client._mqtt_pipeline.on_disconnected == client._on_disconnected + + @pytest.mark.it("Sets on_method_request_received handler in the MQTTPipeline") + def test_sets_on_method_request_received_handler_in_pipleline( + self, client_class, mqtt_pipeline, http_pipeline + ): + client = client_class(mqtt_pipeline, http_pipeline) + + assert client._mqtt_pipeline.on_method_request_received is not None + assert ( + client._mqtt_pipeline.on_method_request_received + == client._inbox_manager.route_method_request + ) + + +@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") +class SharedIoTHubClientCreateMethodUserOptionTests(object): + @pytest.fixture + def option_test_required_patching(self, mocker): + """Override this fixture in a subclass if unique patching is required""" + pass + + @pytest.mark.it( + "Sets the 'product_info' user option parameter on the PipelineConfig, if provided" + ) + def test_product_info_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + + product_info = "MyProductInfo" + client_create_method(*create_method_args, product_info=product_info) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][0] + assert isinstance(config, IoTHubPipelineConfig) + assert config == mock_http_pipeline_init.call_args[0][0] + + assert config.product_info == product_info + + @pytest.mark.it( + "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" + ) + def test_websockets_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + + client_create_method(*create_method_args, websockets=True) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][0] + assert isinstance(config, IoTHubPipelineConfig) + assert config == mock_http_pipeline_init.call_args[0][0] + + assert config.websockets + + # TODO: Show that input in the wrong format is formatted to the correct one. This test exists + # in the IoTHubPipelineConfig object already, but we do not currently show that this is felt + # from the API level. 
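# NOTE: a minimal sketch (class name and client wiring assumed, purely for illustration and
# not part of this changeset) of how a concrete test unit is expected to consume these shared
# option tests, by supplying the 'client_create_method' and 'create_method_args' fixtures that
# the tests above rely on:
#
#     class TestConfigureIoTHubDeviceClientCreateFromConnectionString(
#         SharedIoTHubClientCreateMethodUserOptionTests
#     ):
#         @pytest.fixture
#         def client_create_method(self):
#             # The factory classmethod under test (assumed here to be the device client's)
#             return IoTHubDeviceClient.create_from_connection_string
#
#         @pytest.fixture
#         def create_method_args(self, connection_string):
#             # Positional args that the factory requires, beyond the user options
#             return [connection_string]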
+ @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") + def test_cipher_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" + client_create_method(*create_method_args, cipher=cipher) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][0] + assert isinstance(config, IoTHubPipelineConfig) + assert config == mock_http_pipeline_init.call_args[0][0] + + assert config.cipher == cipher + + @pytest.mark.it( + "Sets the 'server_verification_cert' user option parameter on the PipelineConfig, if provided" + ) + def test_server_verification_cert_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + server_verification_cert = "fake_server_verification_cert" + client_create_method(*create_method_args, server_verification_cert=server_verification_cert) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][0] + assert isinstance(config, IoTHubPipelineConfig) + assert config == mock_http_pipeline_init.call_args[0][0] + + assert config.server_verification_cert == server_verification_cert + + @pytest.mark.it( + "Sets the 'proxy_options' user option parameter on the PipelineConfig, if provided" + ) + def test_proxy_options( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + proxy_options = ProxyOptions(proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=8888) + client_create_method(*create_method_args, proxy_options=proxy_options) + + # Get configuration object, and ensure it was used for both protocol pipelines + assert mock_mqtt_pipeline_init.call_count == 1 + config = mock_mqtt_pipeline_init.call_args[0][0] + assert isinstance(config, IoTHubPipelineConfig) + assert config == mock_http_pipeline_init.call_args[0][0] + + assert config.proxy_options is proxy_options + + @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") + def test_invalid_option( + self, option_test_required_patching, client_create_method, create_method_args + ): + with pytest.raises(TypeError): + client_create_method(*create_method_args, invalid_option="some_value") + + @pytest.mark.it("Sets default user options if none are provided") + def test_default_options( + self, + mocker, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + client_create_method(*create_method_args) + + # Both pipelines use the same IoTHubPipelineConfig + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_http_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] + config = mock_mqtt_pipeline_init.call_args[0][0] + assert isinstance(config, IoTHubPipelineConfig) + + # Pipeline Config has default options set that were not user-specified + assert config.product_info == "" + assert config.websockets is False + assert config.cipher == "" + assert config.proxy_options is None + assert 
config.server_verification_cert is None + + +# TODO: consider splitting this test class up into device/module specific test classes to avoid +# the conditional logic in some tests +@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") +class SharedIoTHubClientCreateFromConnectionStringTests( + SharedIoTHubClientCreateMethodUserOptionTests +): + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_connection_string + + @pytest.fixture + def create_method_args(self, connection_string): + """Provides the specific create method args for use in universal tests""" + return [connection_string] + + @pytest.mark.it( + "Creates a SasToken that uses a SymmetricKeySigningMechanism, from the values in the provided connection string" + ) + def test_sastoken(self, mocker, client_class, connection_string): + sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") + sastoken_mock = mocker.patch.object(st, "SasToken") + cs_obj = cs.ConnectionString(connection_string) + + client_class.create_from_connection_string(connection_string) + + # Determine expected URI based on class under test + if client_class.__name__ == "IoTHubDeviceClient": + expected_uri = "{hostname}/devices/{device_id}".format( + hostname=cs_obj[cs.HOST_NAME], device_id=cs_obj[cs.DEVICE_ID] + ) + else: + expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=cs_obj[cs.HOST_NAME], + device_id=cs_obj[cs.DEVICE_ID], + module_id=cs_obj[cs.MODULE_ID], + ) + + # SymmetricKeySigningMechanism created using the connection string's SharedAccessKey + assert sksm_mock.call_count == 1 + assert sksm_mock.call_args == mocker.call(key=cs_obj[cs.SHARED_ACCESS_KEY]) + + # Token was created with a SymmetricKeySigningMechanism and the expected URI + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call(expected_uri, sksm_mock.return_value) + + @pytest.mark.it( + "Creates MQTT and HTTP Pipelines with an IoTHubPipelineConfig object containing the SasToken and values from the connection string" + ) + def test_pipeline_config( + self, + mocker, + client_class, + connection_string, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + sastoken_mock = mocker.patch.object(st, "SasToken") + cs_obj = cs.ConnectionString(connection_string) + + client_class.create_from_connection_string(connection_string) + + # Verify pipelines created with an IoTHubPipelineConfig + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_http_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] + assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) + + # Verify the IoTHubPipelineConfig is constructed as expected + config = mock_mqtt_pipeline_init.call_args[0][0] + assert config.device_id == cs_obj[cs.DEVICE_ID] + assert config.hostname == cs_obj[cs.HOST_NAME] + assert config.sastoken is sastoken_mock.return_value + if client_class.__name__ == "IoTHubModuleClient": + assert config.module_id == cs_obj[cs.MODULE_ID] + assert config.blob_upload is False + assert config.method_invoke is False + else: + assert config.module_id is None + assert config.blob_upload is True + assert config.method_invoke is False + if cs_obj.get(cs.GATEWAY_HOST_NAME): + assert config.gateway_hostname == cs_obj[cs.GATEWAY_HOST_NAME] + else: + assert config.gateway_hostname is None + + 
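# NOTE: a worked example of the URI expectation asserted in test_sastoken above, using assumed
# placeholder values rather than any real fixture data: a device connection string such as
#     "HostName=myhub.azure-devices.net;DeviceId=my_device;SharedAccessKey=Zm9vYmFy"
# is expected to produce the token URI "myhub.azure-devices.net/devices/my_device", while the
# same string with "ModuleId=my_module" added is expected to produce
#     "myhub.azure-devices.net/devices/my_device/modules/my_module"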
@pytest.mark.it( + "Returns an instance of an IoTHub client using the created MQTT and HTTP pipelines" + ) + def test_client_returned( + self, + mocker, + client_class, + connection_string, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + client = client_class.create_from_connection_string(connection_string) + assert isinstance(client, client_class) + assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value + assert client._http_pipeline is mock_http_pipeline_init.return_value + + @pytest.mark.it("Raises ValueError when given an invalid connection string") + @pytest.mark.parametrize( + "bad_cs", + [ + pytest.param("not-a-connection-string", id="Garbage string"), + pytest.param( + "HostName=value.domain.net;DeviceId=my_device;SharedAccessKey=Invalid", + id="Shared Access Key invalid", + ), + pytest.param( + "HostName=value.domain.net;WrongValue=Invalid;SharedAccessKey=Zm9vYmFy", + id="Contains extraneous data", + ), + pytest.param("HostName=value.domain.net;DeviceId=my_device", id="Incomplete"), + ], + ) + def test_raises_value_error_on_bad_connection_string(self, client_class, bad_cs): + with pytest.raises(ValueError): + client_class.create_from_connection_string(bad_cs) + + @pytest.mark.it("Raises ValueError if a SasToken creation results in failure") + def test_raises_value_error_on_sastoken_failure(self, mocker, client_class, connection_string): + sastoken_mock = mocker.patch.object(st, "SasToken") + token_err = st.SasTokenError("Some SasToken failure") + sastoken_mock.side_effect = token_err + + with pytest.raises(ValueError) as e_info: + client_class.create_from_connection_string(connection_string) + assert e_info.value.__cause__ is token_err + + +# NOTE: If more properties are added, this class should become a general purpose properties testclass +class SharedIoTHubClientPROPERTYConnectedTests(object): + @pytest.mark.it("Cannot be changed") + def test_read_only(self, client): + with pytest.raises(AttributeError): + client.connected = not client.connected + + @pytest.mark.it("Reflects the value of the root stage property of the same name") + def test_reflects_pipeline_property(self, client, mqtt_pipeline): + mqtt_pipeline.connected = True + assert client.connected + mqtt_pipeline.connected = False + assert not client.connected + + +############################## +# SHARED DEVICE CLIENT TESTS # +############################## + + +@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") +class SharedIoTHubDeviceClientCreateFromSymmetricKeyTests( + SharedIoTHubClientCreateMethodUserOptionTests +): + hostname = "durmstranginstitute.farend" + device_id = "MySnitch" + symmetric_key = "Zm9vYmFy" + + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_symmetric_key + + @pytest.fixture + def create_method_args(self): + """Provides the specific create method args for use in universal tests""" + return [self.symmetric_key, self.hostname, self.device_id] + + @pytest.mark.it( + "Creates a SasToken that uses a SymmetricKeySigningMechanism, from the values provided in parameters" + ) + def test_sastoken(self, mocker, client_class): + sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") + sastoken_mock = mocker.patch.object(st, "SasToken") + expected_uri = "{hostname}/devices/{device_id}".format( + hostname=self.hostname, device_id=self.device_id + ) + + client_class.create_from_symmetric_key( + 
symmetric_key=self.symmetric_key, hostname=self.hostname, device_id=self.device_id + ) + + # SymmetricKeySigningMechanism created using the provided symmetric key + assert sksm_mock.call_count == 1 + assert sksm_mock.call_args == mocker.call(key=self.symmetric_key) + + # SasToken created with the SymmetricKeySigningMechanism and the expected URI + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call(expected_uri, sksm_mock.return_value) + + @pytest.mark.it( + "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the SasToken and values provided in parameters" + ) + def test_pipeline_config( + self, mocker, client_class, mock_mqtt_pipeline_init, mock_http_pipeline_init + ): + sastoken_mock = mocker.patch.object(st, "SasToken") + + client_class.create_from_symmetric_key( + symmetric_key=self.symmetric_key, hostname=self.hostname, device_id=self.device_id + ) + + # Verify pipelines created with an IoTHubPipelineConfig + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_http_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] + assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) + + # Verify the IoTHubPipelineConfig is constructed as expected + config = mock_mqtt_pipeline_init.call_args[0][0] + assert config.device_id == self.device_id + assert config.hostname == self.hostname + assert config.gateway_hostname is None + assert config.sastoken is sastoken_mock.return_value + assert config.blob_upload is True + assert config.method_invoke is False + + @pytest.mark.it( + "Returns an instance of an IoTHubDeviceClient using the created MQTT and HTTP pipelines" + ) + def test_client_returned( + self, mocker, client_class, mock_mqtt_pipeline_init, mock_http_pipeline_init + ): + client = client_class.create_from_symmetric_key( + symmetric_key=self.symmetric_key, hostname=self.hostname, device_id=self.device_id + ) + assert isinstance(client, client_class) + assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value + assert client._http_pipeline is mock_http_pipeline_init.return_value + + @pytest.mark.it("Raises ValueError if a SasToken creation results in failure") + def test_raises_value_error_on_sastoken_failure(self, mocker, client_class): + sastoken_mock = mocker.patch.object(st, "SasToken") + token_err = st.SasTokenError("Some SasToken failure") + sastoken_mock.side_effect = token_err + + with pytest.raises(ValueError) as e_info: + client_class.create_from_symmetric_key( + symmetric_key=self.symmetric_key, hostname=self.hostname, device_id=self.device_id + ) + assert e_info.value.__cause__ is token_err + + +@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") +class SharedIoTHubDeviceClientCreateFromX509CertificateTests( + SharedIoTHubClientCreateMethodUserOptionTests +): + hostname = "durmstranginstitute.farend" + device_id = "MySnitch" + + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_x509_certificate + + @pytest.fixture + def create_method_args(self, x509): + """Provides the specific create method args for use in universal tests""" + return [x509, self.hostname, self.device_id] + + @pytest.mark.it( + "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the X509 and other values provided in parameters" + ) + def test_pipeline_config( + self, 
mocker, client_class, x509, mock_mqtt_pipeline_init, mock_http_pipeline_init + ): + client_class.create_from_x509_certificate( + x509=x509, hostname=self.hostname, device_id=self.device_id + ) + + # Verify pipelines created with an IoTHubPipelineConfig + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_http_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args[0][0] == mock_http_pipeline_init.call_args[0][0] + assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) + + # Verify the IoTHubPipelineConfig is constructed as expected + config = mock_mqtt_pipeline_init.call_args[0][0] + assert config.device_id == self.device_id + assert config.hostname == self.hostname + assert config.gateway_hostname is None + assert config.x509 is x509 + assert config.blob_upload is True + assert config.method_invoke is False + + @pytest.mark.it( + "Returns an instance of an IoTHubDeviceClient using the created MQTT and HTTP pipelines" + ) + def test_client_returned( + self, mocker, client_class, x509, mock_mqtt_pipeline_init, mock_http_pipeline_init + ): + client = client_class.create_from_x509_certificate( + x509=x509, hostname=self.hostname, device_id=self.device_id + ) + assert isinstance(client, client_class) + assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value + assert client._http_pipeline is mock_http_pipeline_init.return_value + + +############################## +# SHARED MODULE CLIENT TESTS # +############################## + + +@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") +class SharedIoTHubModuleClientCreateFromX509CertificateTests( + SharedIoTHubClientCreateMethodUserOptionTests +): + hostname = "durmstranginstitute.farend" + device_id = "MySnitch" + module_id = "Charms" + + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_x509_certificate + + @pytest.fixture + def create_method_args(self, x509): + """Provides the specific create method args for use in universal tests""" + return [x509, self.hostname, self.device_id, self.module_id] + + @pytest.mark.it( + "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the X509 and other values provided in parameters" + ) + def test_pipeline_config( + self, mocker, client_class, x509, mock_mqtt_pipeline_init, mock_http_pipeline_init + ): + client_class.create_from_x509_certificate( + x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id + ) + + # Verify pipelines created with an IoTHubPipelineConfig + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_http_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args[0][0] == mock_http_pipeline_init.call_args[0][0] + assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) + + # Verify the IoTHubPipelineConfig is constructed as expected + config = mock_mqtt_pipeline_init.call_args[0][0] + assert config.device_id == self.device_id + assert config.hostname == self.hostname + assert config.gateway_hostname is None + assert config.x509 is x509 + assert config.blob_upload is False + assert config.method_invoke is False + + @pytest.mark.it( + "Returns an instance of an IoTHubModuleClient using the created MQTT and HTTP pipelines" + ) + def test_client_returned( + self, mocker, client_class, x509, mock_mqtt_pipeline_init, mock_http_pipeline_init + ): + client = 
client_class.create_from_x509_certificate( + x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id + ) + assert isinstance(client, client_class) + assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value + assert client._http_pipeline is mock_http_pipeline_init.return_value + + +@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") +class SharedIoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests( + SharedIoTHubClientCreateMethodUserOptionTests +): + """This class inherits the user option tests shared by all create method APIs, and overrides + tests in order to accommodate unique requirements for the .create_from_edge_environment() method. + + Because .create_from_edge_environment() tests are spread across multiple test units + (i.e. test classes), these overrides are done in this class, which is then inherited by all + .create_from_edge_environment() test units below. + """ + + @pytest.fixture + def client_create_method(self, client_class): + """Provides the specific create method for use in universal tests""" + return client_class.create_from_edge_environment + + @pytest.fixture + def create_method_args(self): + """Provides the specific create method args for use in universal tests""" + return [] + + @pytest.mark.it( + "Raises a TypeError if the 'server_verification_cert' user option parameter is provided" + ) + def test_server_verification_cert_option( + self, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + """THIS TEST OVERRIDES AN INHERITED TEST""" + # Override to test that server_verification_cert CANNOT be provided in Edge scenarios + + with pytest.raises(TypeError): + client_create_method( + *create_method_args, server_verification_cert="fake_server_verification_cert" + ) + + @pytest.mark.it("Sets default user options if none are provided") + def test_default_options( + self, + mocker, + option_test_required_patching, + client_create_method, + create_method_args, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + """THIS TEST OVERRIDES AN INHERITED TEST""" + # Override so that we can avoid the check on server_verification_cert being None + # as in Edge scenarios, it is not None + + client_create_method(*create_method_args) + + # Both pipelines use the same IoTHubPipelineConfig + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_http_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] + config = mock_mqtt_pipeline_init.call_args[0][0] + assert isinstance(config, IoTHubPipelineConfig) + + # Pipeline Config has default options that were not specified + assert config.product_info == "" + assert config.websockets is False + assert config.cipher == "" + assert config.proxy_options is None + + +@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") +class SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnvTests( + SharedIoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests +): + @pytest.fixture + def option_test_required_patching(self, mocker, mock_edge_hsm, edge_container_environment): + """THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" + mocker.patch.dict(os.environ, edge_container_environment, clear=True) + + @pytest.mark.it( + "Creates a SasToken that uses an IoTEdgeHsm, from the values extracted from the Edge environment" + ) + def test_sastoken(self, mocker, client_class, 
mock_edge_hsm, edge_container_environment): + mocker.patch.dict(os.environ, edge_container_environment, clear=True) + sastoken_mock = mocker.patch.object(st, "SasToken") + + expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"], + device_id=edge_container_environment["IOTEDGE_DEVICEID"], + module_id=edge_container_environment["IOTEDGE_MODULEID"], + ) + + client_class.create_from_edge_environment() + + # IoTEdgeHsm created using the extracted values + assert mock_edge_hsm.call_count == 1 + assert mock_edge_hsm.call_args == mocker.call( + module_id=edge_container_environment["IOTEDGE_MODULEID"], + generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], + workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], + api_version=edge_container_environment["IOTEDGE_APIVERSION"], + ) + + # SasToken created with the IoTEdgeHsm and the expected URI + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call(expected_uri, mock_edge_hsm.return_value) + + @pytest.mark.it( + "Uses an IoTEdgeHsm as the SasToken signing mechanism even if any Edge local debug environment variables may also be present" + ) + def test_hybrid_env( + self, + mocker, + client_class, + mock_edge_hsm, + edge_container_environment, + edge_local_debug_environment, + ): + hybrid_environment = merge_dicts(edge_container_environment, edge_local_debug_environment) + mocker.patch.dict(os.environ, hybrid_environment, clear=True) + sastoken_mock = mocker.patch.object(st, "SasToken") + mock_sksm = mocker.patch.object(auth, "SymmetricKeySigningMechanism") + + client_class.create_from_edge_environment() + + assert mock_sksm.call_count == 0 # we did NOT use SK signing mechanism + assert mock_edge_hsm.call_count == 1 # instead, we still used edge hsm + assert mock_edge_hsm.call_args == mocker.call( + module_id=edge_container_environment["IOTEDGE_MODULEID"], + generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], + workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], + api_version=edge_container_environment["IOTEDGE_APIVERSION"], + ) + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call(mocker.ANY, mock_edge_hsm.return_value) + + @pytest.mark.it( + "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the SasToken and values extracted from the Edge environment" + ) + def test_pipeline_config( + self, + mocker, + client_class, + mock_edge_hsm, + edge_container_environment, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + mocker.patch.dict(os.environ, edge_container_environment, clear=True) + sastoken_mock = mocker.patch.object(st, "SasToken") + + client_class.create_from_edge_environment() + + # Verify pipelines created with an IoTHubPipelineConfig + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_http_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] + assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) + + # Verify the IoTHubPipelineConfig is constructed as expected + config = mock_mqtt_pipeline_init.call_args[0][0] + assert config.device_id == edge_container_environment["IOTEDGE_DEVICEID"] + assert config.module_id == edge_container_environment["IOTEDGE_MODULEID"] + assert config.hostname == edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"] + assert config.gateway_hostname == 
edge_container_environment["IOTEDGE_GATEWAYHOSTNAME"] + assert config.sastoken is sastoken_mock.return_value + assert ( + config.server_verification_cert + == mock_edge_hsm.return_value.get_certificate.return_value + ) + assert config.method_invoke is True + assert config.blob_upload is False + + @pytest.mark.it( + "Returns an instance of an IoTHubModuleClient using the created MQTT and HTTP pipelines" + ) + def test_client_returns( + self, + mocker, + client_class, + mock_edge_hsm, + edge_container_environment, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + mocker.patch.dict(os.environ, edge_container_environment, clear=True) + + client = client_class.create_from_edge_environment() + assert isinstance(client, client_class) + assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value + assert client._http_pipeline is mock_http_pipeline_init.return_value + + @pytest.mark.it("Raises OSError if the environment is missing required variables") + @pytest.mark.parametrize( + "missing_env_var", + [ + "IOTEDGE_MODULEID", + "IOTEDGE_DEVICEID", + "IOTEDGE_IOTHUBHOSTNAME", + "IOTEDGE_GATEWAYHOSTNAME", + "IOTEDGE_APIVERSION", + "IOTEDGE_MODULEGENERATIONID", + "IOTEDGE_WORKLOADURI", + ], + ) + def test_bad_environment( + self, mocker, client_class, edge_container_environment, missing_env_var + ): + # Remove a variable from the fixture + del edge_container_environment[missing_env_var] + mocker.patch.dict(os.environ, edge_container_environment, clear=True) + + with pytest.raises(OSError): + client_class.create_from_edge_environment() + + @pytest.mark.it( + "Raises OSError if there is an error retrieving the server verification certificate from Edge with the IoTEdgeHsm" + ) + def test_bad_edge_auth(self, mocker, client_class, edge_container_environment, mock_edge_hsm): + mocker.patch.dict(os.environ, edge_container_environment, clear=True) + my_edge_error = edge_hsm.IoTEdgeError() + mock_edge_hsm.return_value.get_certificate.side_effect = my_edge_error + + with pytest.raises(OSError) as e_info: + client_class.create_from_edge_environment() + assert e_info.value.__cause__ is my_edge_error + + @pytest.mark.it("Raises ValueError if a SasToken creation results in failure") + def test_raises_value_error_on_sastoken_failure( + self, mocker, client_class, edge_container_environment, mock_edge_hsm + ): + mocker.patch.dict(os.environ, edge_container_environment, clear=True) + sastoken_mock = mocker.patch.object(st, "SasToken") + token_err = st.SasTokenError("Some SasToken failure") + sastoken_mock.side_effect = token_err + + with pytest.raises(ValueError) as e_info: + client_class.create_from_edge_environment() + assert e_info.value.__cause__ is token_err + + +@pytest.mark.usefixtures("mock_mqtt_pipeline_init", "mock_http_pipeline_init") +class SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnvTests( + SharedIoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests +): + @pytest.fixture + def option_test_required_patching(self, mocker, mock_open, edge_local_debug_environment): + """THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + + @pytest.fixture + def mock_open(self, mocker): + return mocker.patch.object(io, "open") + + @pytest.mark.it( + "Creates a SasToken that uses a SymmetricKeySigningMechanism, from the values in the connection string extracted from the Edge local debug environment" + ) + def test_sastoken(self, mocker, client_class, mock_open, edge_local_debug_environment): + 
mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") + sastoken_mock = mocker.patch.object(st, "SasToken") + cs_obj = cs.ConnectionString(edge_local_debug_environment["EdgeHubConnectionString"]) + expected_uri = "{hostname}/devices/{device_id}/modules/{module_id}".format( + hostname=cs_obj[cs.HOST_NAME], + device_id=cs_obj[cs.DEVICE_ID], + module_id=cs_obj[cs.MODULE_ID], + ) + + client_class.create_from_edge_environment() + + # SymmetricKeySigningMechanism created using the connection string's Shared Access Key + assert sksm_mock.call_count == 1 + assert sksm_mock.call_args == mocker.call(key=cs_obj[cs.SHARED_ACCESS_KEY]) + + # SasToken created with the SymmetricKeySigningMechanism and the expected URI + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call(expected_uri, sksm_mock.return_value) + + @pytest.mark.it( + "Only uses Edge local debug variables if no Edge container variables are present in the environment" + ) + def test_auth_provider_and_pipeline_hybrid_env( + self, + mocker, + client_class, + edge_container_environment, + edge_local_debug_environment, + mock_open, + mock_edge_hsm, + ): + # This test verifies that the presence of edge container environment variables means the + # code will follow the edge container environment creation path (using the IoTEdgeHsm) + # even if edge local debug variables are present. + hybrid_environment = merge_dicts(edge_container_environment, edge_local_debug_environment) + mocker.patch.dict(os.environ, hybrid_environment, clear=True) + sastoken_mock = mocker.patch.object(st, "SasToken") + sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") + + client_class.create_from_edge_environment() + + assert sksm_mock.call_count == 0 # we did NOT use SK signing mechanism + assert mock_edge_hsm.call_count == 1 # instead, we still used edge HSM + assert mock_edge_hsm.call_args == mocker.call( + module_id=edge_container_environment["IOTEDGE_MODULEID"], + generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], + workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], + api_version=edge_container_environment["IOTEDGE_APIVERSION"], + ) + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call(mocker.ANY, mock_edge_hsm.return_value) + + @pytest.mark.it( + "Extracts the server verification certificate from the file indicated by the filepath extracted from the Edge local debug environment" + ) + def test_open_ca_cert(self, mocker, client_class, edge_local_debug_environment, mock_open): + mock_file_handle = mock_open.return_value.__enter__.return_value + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + + client_class.create_from_edge_environment() + + assert mock_open.call_count == 1 + assert mock_open.call_args == mocker.call( + edge_local_debug_environment["EdgeModuleCACertificateFile"], mode="r" + ) + assert mock_file_handle.read.call_count == 1 + assert mock_file_handle.read.call_args == mocker.call() + + @pytest.mark.it( + "Creates MQTT and HTTP pipelines with an IoTHubPipelineConfig object containing the SasToken and values extracted from the Edge local debug environment" + ) + def test_pipeline_config( + self, + mocker, + client_class, + mock_open, + edge_local_debug_environment, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + sastoken_mock = 
mocker.patch.object(st, "SasToken") + mock_file_handle = mock_open.return_value.__enter__.return_value + ca_cert_file_contents = "some cert" + mock_file_handle.read.return_value = ca_cert_file_contents + + cs_obj = cs.ConnectionString(edge_local_debug_environment["EdgeHubConnectionString"]) + + client_class.create_from_edge_environment() + + # Verify pipelines created with an IoTHubPipelineConfig + assert mock_mqtt_pipeline_init.call_count == 1 + assert mock_http_pipeline_init.call_count == 1 + assert mock_mqtt_pipeline_init.call_args[0][0] is mock_http_pipeline_init.call_args[0][0] + assert isinstance(mock_mqtt_pipeline_init.call_args[0][0], IoTHubPipelineConfig) + + # Verify the IoTHubPipelineConfig is constructed as expected + config = mock_mqtt_pipeline_init.call_args[0][0] + assert config.device_id == cs_obj[cs.DEVICE_ID] + assert config.module_id == cs_obj[cs.MODULE_ID] + assert config.hostname == cs_obj[cs.HOST_NAME] + assert config.gateway_hostname == cs_obj[cs.GATEWAY_HOST_NAME] + assert config.sastoken is sastoken_mock.return_value + assert config.server_verification_cert == ca_cert_file_contents + assert config.method_invoke is True + assert config.blob_upload is False + + @pytest.mark.it( + "Returns an instance of an IoTHub client using the created MQTT and HTTP pipelines" + ) + def test_client_returned( + self, + mocker, + client_class, + mock_open, + edge_local_debug_environment, + mock_mqtt_pipeline_init, + mock_http_pipeline_init, + ): + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + + client = client_class.create_from_edge_environment() + + assert isinstance(client, client_class) + assert client._mqtt_pipeline is mock_mqtt_pipeline_init.return_value + assert client._http_pipeline is mock_http_pipeline_init.return_value + + @pytest.mark.it("Raises OSError if the environment is missing required variables") + @pytest.mark.parametrize( + "missing_env_var", ["EdgeHubConnectionString", "EdgeModuleCACertificateFile"] + ) + def test_bad_environment( + self, mocker, client_class, edge_local_debug_environment, missing_env_var, mock_open + ): + # Remove a variable from the fixture + del edge_local_debug_environment[missing_env_var] + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + + with pytest.raises(OSError): + client_class.create_from_edge_environment() + + @pytest.mark.it( + "Raises ValueError if the connection string in the EdgeHubConnectionString environment variable is invalid" + ) + @pytest.mark.parametrize( + "bad_cs", + [ + pytest.param("not-a-connection-string", id="Garbage string"), + pytest.param( + "HostName=value.domain.net;DeviceId=my_device;ModuleId=my_module;SharedAccessKey=Invalid", + id="Shared Access Key invalid", + ), + pytest.param( + "HostName=value.domain.net;WrongValue=Invalid;SharedAccessKey=Zm9vYmFy", + id="Contains extraneous data", + ), + pytest.param("HostName=value.domain.net;DeviceId=my_device", id="Incomplete"), + ], + ) + def test_bad_connection_string( + self, mocker, client_class, edge_local_debug_environment, bad_cs, mock_open + ): + edge_local_debug_environment["EdgeHubConnectionString"] = bad_cs + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + + with pytest.raises(ValueError): + client_class.create_from_edge_environment() + + @pytest.mark.it( + "Raises ValueError if the filepath in the EdgeModuleCACertificateFile environment variable is invalid" + ) + def test_bad_filepath(self, mocker, client_class, edge_local_debug_environment, mock_open): + # To make tests 
compatible with Python 2 & 3, redefine errors + try: + FileNotFoundError # noqa: F823 + except NameError: + FileNotFoundError = IOError + + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + my_fnf_error = FileNotFoundError() + mock_open.side_effect = my_fnf_error + with pytest.raises(ValueError) as e_info: + client_class.create_from_edge_environment() + assert e_info.value.__cause__ is my_fnf_error + + @pytest.mark.it( + "Raises ValueError if the file referenced by the filepath in the EdgeModuleCACertificateFile environment variable cannot be opened" + ) + def test_bad_file_io(self, mocker, client_class, edge_local_debug_environment, mock_open): + # Raise a different error in Python 2 vs 3 + if six.PY2: + error = IOError() + else: + error = OSError() + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + mock_open.side_effect = error + with pytest.raises(ValueError) as e_info: + client_class.create_from_edge_environment() + assert e_info.value.__cause__ is error + + @pytest.mark.it("Raises ValueError if a SasToken creation results in failure") + def test_raises_value_error_on_sastoken_failure( + self, mocker, client_class, edge_local_debug_environment, mock_open + ): + mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) + sastoken_mock = mocker.patch.object(st, "SasToken") + token_err = st.SasTokenError("Some SasToken failure") + sastoken_mock.side_effect = token_err + + with pytest.raises(ValueError) as e_info: + client_class.create_from_edge_environment() + assert e_info.value.__cause__ is token_err + + +#################### +# HELPER FUNCTIONS # +#################### +def merge_dicts(d1, d2): + d3 = d1.copy() + d3.update(d2) + return d3 diff --git a/azure-iot-device/tests/iothub/test_edge_hsm.py b/azure-iot-device/tests/iothub/test_edge_hsm.py new file mode 100644 index 000000000..121c604ef --- /dev/null +++ b/azure-iot-device/tests/iothub/test_edge_hsm.py @@ -0,0 +1,253 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import pytest +import logging +import requests +import json +import base64 +from six.moves import urllib +from azure.iot.device.iothub.edge_hsm import IoTEdgeHsm, IoTEdgeError +from azure.iot.device import user_agent + + +logging.basicConfig(level=logging.DEBUG) + + +@pytest.fixture +def edge_hsm(): + return IoTEdgeHsm( + module_id="my_module_id", + generation_id="module_generation_id", + workload_uri="unix:///var/run/iotedge/workload.sock", + api_version="my_api_version", + ) + + +@pytest.mark.describe("IoTEdgeHsm - Instantiation") +class TestIoTEdgeHsmInstantiation(object): + @pytest.mark.it("URL encodes the provided module_id parameter and sets it as an attribute") + def test_encode_and_set_module_id(self): + module_id = "my_module_id" + generation_id = "my_generation_id" + api_version = "my_api_version" + workload_uri = "unix:///var/run/iotedge/workload.sock" + + edge_hsm = IoTEdgeHsm( + module_id=module_id, + generation_id=generation_id, + workload_uri=workload_uri, + api_version=api_version, + ) + + assert edge_hsm.module_id == urllib.parse.quote(module_id, safe="") + + @pytest.mark.it( + "Formats the provided workload_uri parameter for use with the requests library and sets it as an attribute" + ) + @pytest.mark.parametrize( + "workload_uri, expected_formatted_uri", + [ + pytest.param( + "unix:///var/run/iotedge/workload.sock", + "http+unix://%2Fvar%2Frun%2Fiotedge%2Fworkload.sock/", + id="Domain Socket URI", + ), + pytest.param("http://127.0.0.1:15580", "http://127.0.0.1:15580/", id="IP Address URI"), + ], + ) + def test_workload_uri_formatting(self, workload_uri, expected_formatted_uri): + module_id = "my_module_id" + generation_id = "my_generation_id" + api_version = "my_api_version" + + edge_hsm = IoTEdgeHsm( + module_id=module_id, + generation_id=generation_id, + workload_uri=workload_uri, + api_version=api_version, + ) + + assert edge_hsm.workload_uri == expected_formatted_uri + + @pytest.mark.it("Sets the provided generation_id parameter as an attribute") + def test_set_generation_id(self): + module_id = "my_module_id" + generation_id = "my_generation_id" + api_version = "my_api_version" + workload_uri = "unix:///var/run/iotedge/workload.sock" + + edge_hsm = IoTEdgeHsm( + module_id=module_id, + generation_id=generation_id, + workload_uri=workload_uri, + api_version=api_version, + ) + + assert edge_hsm.generation_id == generation_id + + @pytest.mark.it("Sets the provided api_version parameter as an attribute") + def test_set_api_verison(self): + module_id = "my_module_id" + generation_id = "my_generation_id" + api_version = "my_api_version" + workload_uri = "unix:///var/run/iotedge/workload.sock" + + edge_hsm = IoTEdgeHsm( + module_id=module_id, + generation_id=generation_id, + workload_uri=workload_uri, + api_version=api_version, + ) + + assert edge_hsm.api_version == api_version + + +@pytest.mark.describe("IoTEdgeHsm - .get_certificate()") +class TestIoTEdgeHsmGetCertificate(object): + @pytest.mark.it("Sends an HTTP GET request to retrieve the trust bundle from Edge") + def test_requests_trust_bundle(self, mocker, edge_hsm): + mock_request_get = mocker.patch.object(requests, "get") + expected_url = edge_hsm.workload_uri + "trust-bundle" + expected_params = {"api-version": edge_hsm.api_version} + expected_headers = { + "User-Agent": urllib.parse.quote_plus(user_agent.get_iothub_user_agent()) + } + + edge_hsm.get_certificate() + + assert mock_request_get.call_count == 1 + assert 
mock_request_get.call_args == mocker.call( + expected_url, params=expected_params, headers=expected_headers + ) + + @pytest.mark.it("Returns the certificate from the trust bundle received from Edge") + def test_returns_certificate(self, mocker, edge_hsm): + mock_request_get = mocker.patch.object(requests, "get") + mock_response = mock_request_get.return_value + certificate = "my certificate" + mock_response.json.return_value = {"certificate": certificate} + + returned_cert = edge_hsm.get_certificate() + + assert returned_cert is certificate + + @pytest.mark.it("Raises IoTEdgeError if a bad request is made to Edge") + def test_bad_request(self, mocker, edge_hsm): + mock_request_get = mocker.patch.object(requests, "get") + mock_response = mock_request_get.return_value + error = requests.exceptions.HTTPError() + mock_response.raise_for_status.side_effect = error + + with pytest.raises(IoTEdgeError) as e_info: + edge_hsm.get_certificate() + assert e_info.value.__cause__ is error + + @pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the trust bundle") + def test_bad_json(self, mocker, edge_hsm): + mock_request_get = mocker.patch.object(requests, "get") + mock_response = mock_request_get.return_value + error = ValueError() + mock_response.json.side_effect = error + + with pytest.raises(IoTEdgeError) as e_info: + edge_hsm.get_certificate() + assert e_info.value.__cause__ is error + + @pytest.mark.it("Raises IoTEdgeError if the certificate is missing from the trust bundle") + def test_bad_trust_bundle(self, mocker, edge_hsm): + mock_request_get = mocker.patch.object(requests, "get") + mock_response = mock_request_get.return_value + # Return an empty json dict with no 'certificate' key + mock_response.json.return_value = {} + + with pytest.raises(IoTEdgeError): + edge_hsm.get_certificate() + + +@pytest.mark.describe("IoTEdgeHsm - .sign()") +class TestIoTEdgeHsmSign(object): + @pytest.mark.it( + "Makes an HTTP request to Edge to sign a piece of string data using the HMAC-SHA256 algorithm" + ) + def test_requests_data_signing(self, mocker, edge_hsm): + data_str = "somedata" + data_str_b64 = "c29tZWRhdGE=" + mock_request_post = mocker.patch.object(requests, "post") + mock_request_post.return_value.json.return_value = {"digest": "somedigest"} + expected_url = "{workload_uri}modules/{module_id}/genid/{generation_id}/sign".format( + workload_uri=edge_hsm.workload_uri, + module_id=edge_hsm.module_id, + generation_id=edge_hsm.generation_id, + ) + expected_params = {"api-version": edge_hsm.api_version} + expected_headers = { + "User-Agent": urllib.parse.quote(user_agent.get_iothub_user_agent(), safe="") + } + expected_json = json.dumps({"keyId": "primary", "algo": "HMACSHA256", "data": data_str_b64}) + + edge_hsm.sign(data_str) + + assert mock_request_post.call_count == 1 + assert mock_request_post.call_args == mocker.call( + url=expected_url, params=expected_params, headers=expected_headers, data=expected_json + ) + + @pytest.mark.it("Base64 encodes the string data in the request") + def test_b64_encodes_data(self, mocker, edge_hsm): + # This test is actually implicitly tested in the first test, but it's + # important to have an explicit test for it since it's a requirement + data_str = "somedata" + data_str_b64 = base64.b64encode(data_str.encode("utf-8")).decode() + mock_request_post = mocker.patch.object(requests, "post") + mock_request_post.return_value.json.return_value = {"digest": "somedigest"} + + edge_hsm.sign(data_str) + + sent_data = 
json.loads(mock_request_post.call_args[1]["data"])["data"] + + assert data_str != data_str_b64 + assert sent_data == data_str_b64 + + @pytest.mark.it("Returns the signed data received from Edge") + def test_returns_signed_data(self, mocker, edge_hsm): + expected_digest = "somedigest" + mock_request_post = mocker.patch.object(requests, "post") + mock_request_post.return_value.json.return_value = {"digest": expected_digest} + + signed_data = edge_hsm.sign("somedata") + + assert signed_data == expected_digest + + @pytest.mark.it("Raises IoTEdgeError if a bad request is made to EdgeHub") + def test_bad_request(self, mocker, edge_hsm): + mock_request_post = mocker.patch.object(requests, "post") + mock_response = mock_request_post.return_value + error = requests.exceptions.HTTPError() + mock_response.raise_for_status.side_effect = error + + with pytest.raises(IoTEdgeError) as e_info: + edge_hsm.sign("somedata") + assert e_info.value.__cause__ is error + + @pytest.mark.it("Raises IoTEdgeError if there is an error in json decoding the signed response") + def test_bad_json(self, mocker, edge_hsm): + mock_request_post = mocker.patch.object(requests, "post") + mock_response = mock_request_post.return_value + error = ValueError() + mock_response.json.side_effect = error + with pytest.raises(IoTEdgeError) as e_info: + edge_hsm.sign("somedata") + assert e_info.value.__cause__ is error + + @pytest.mark.it("Raises IoTEdgeError if the signed data is missing from the response") + def test_bad_response(self, mocker, edge_hsm): + mock_request_post = mocker.patch.object(requests, "post") + mock_response = mock_request_post.return_value + mock_response.json.return_value = {} + + with pytest.raises(IoTEdgeError): + edge_hsm.sign("somedata") diff --git a/azure-iot-device/tests/iothub/test_sync_clients.py b/azure-iot-device/tests/iothub/test_sync_clients.py index eb1ab614a..9329d67d7 100644 --- a/azure-iot-device/tests/iothub/test_sync_clients.py +++ b/azure-iot-device/tests/iothub/test_sync_clients.py @@ -13,276 +13,28 @@ import six from azure.iot.device.iothub import IoTHubDeviceClient, IoTHubModuleClient from azure.iot.device import exceptions as client_exceptions -from azure.iot.device.iothub.pipeline import MQTTPipeline, constant +from azure.iot.device.iothub.pipeline import constant from azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions from azure.iot.device.iothub.models import Message, MethodRequest from azure.iot.device.iothub.sync_inbox import SyncClientInbox -from azure.iot.device.iothub.auth import IoTEdgeError from azure.iot.device import constant as device_constant +from .shared_client_tests import ( + SharedIoTHubClientInstantiationTests, + SharedIoTHubClientPROPERTYConnectedTests, + SharedIoTHubClientCreateFromConnectionStringTests, + SharedIoTHubDeviceClientCreateFromSymmetricKeyTests, + SharedIoTHubDeviceClientCreateFromX509CertificateTests, + SharedIoTHubModuleClientCreateFromX509CertificateTests, + SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnvTests, + SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnvTests, +) logging.basicConfig(level=logging.DEBUG) -# automatically mock the mqtt pipeline for all tests in this file. -@pytest.fixture(autouse=True) -def mock_mqtt_pipeline_init(mocker): - return mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline") - - -# automatically mock the http pipeline for all tests in this file. 
-@pytest.fixture(autouse=True) -def mock_http_pipeline_init(mocker): - return mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") - - -################ -# SHARED TESTS # -################ -class SharedClientInstantiationTests(object): - @pytest.mark.it( - "Stores the MQTTPipeline from the 'mqtt_pipeline' parameter in the '_mqtt_pipeline' attribute" - ) - def test_mqtt_pipeline_attribute(self, client_class, mqtt_pipeline, http_pipeline): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline is mqtt_pipeline - - @pytest.mark.it( - "Stores the HTTPPipeline from the 'http_pipeline' parameter in the '_http_pipeline' attribute" - ) - def test_sets_http_pipeline_attribute(self, client_class, mqtt_pipeline, http_pipeline): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._http_pipeline is http_pipeline - - @pytest.mark.it("Sets on_connected handler in the MQTTPipeline") - def test_sets_on_connected_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_connected is not None - assert client._mqtt_pipeline.on_connected == client._on_connected - - @pytest.mark.it("Sets on_disconnected handler in the MQTTPipeline") - def test_sets_on_disconnected_handler_in_pipeline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_disconnected is not None - assert client._mqtt_pipeline.on_disconnected == client._on_disconnected - - @pytest.mark.it("Sets on_method_request_received handler in the MQTTPipeline") - def test_sets_on_method_request_received_handler_in_pipleline( - self, client_class, mqtt_pipeline, http_pipeline - ): - client = client_class(mqtt_pipeline, http_pipeline) - - assert client._mqtt_pipeline.on_method_request_received is not None - assert ( - client._mqtt_pipeline.on_method_request_received - == client._inbox_manager.route_method_request - ) - - -class SharedClientCreateMethodUserOptionTests(object): - # In these tests we patch the entire 'auth' library instead of specific auth providers in order - # to make them more generic, and applicable across all creation methods. 
- - @pytest.fixture - def option_test_required_patching(self, mocker): - """Override this fixture in a subclass if unique patching is required""" - pass - - @pytest.mark.it( - "Sets the 'product_info' user option parameter on the PipelineConfig, if provided" - ) - def test_product_info_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - - product_info = "MyProductInfo" - client_create_method(*create_method_args, product_info=product_info) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][1] - assert config == mock_http_pipeline_init.call_args[0][1] - - assert config.product_info == product_info - - @pytest.mark.it( - "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" - ) - def test_websockets_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - - client_create_method(*create_method_args, websockets=True) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][1] - assert config == mock_http_pipeline_init.call_args[0][1] - - assert config.websockets - - # TODO: Show that input in the wrong format is formatted to the correct one. This test exists - # in the IoTHubPipelineConfig object already, but we do not currently show that this is felt - # from the API level. - @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") - def test_cipher_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" - client_create_method(*create_method_args, cipher=cipher) - - # Get configuration object, and ensure it was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - config = mock_mqtt_pipeline_init.call_args[0][1] - assert config == mock_http_pipeline_init.call_args[0][1] - - assert config.cipher == cipher - - @pytest.mark.it( - "Sets the 'server_verification_cert' user option parameter on the AuthenticationProvider, if provided" - ) - def test_server_verification_cert_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - server_verification_cert = "fake_server_verification_cert" - client_create_method(*create_method_args, server_verification_cert=server_verification_cert) - - # Get auth provider object, and ensure it was used for both protocol pipelines - auth = mock_mqtt_pipeline_init.call_args[0][0] - assert auth == mock_http_pipeline_init.call_args[0][0] - - assert auth.server_verification_cert == server_verification_cert - - @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") - def test_invalid_option( - self, option_test_required_patching, client_create_method, create_method_args - ): - with pytest.raises(TypeError): - client_create_method(*create_method_args, invalid_option="some_value") - - @pytest.mark.it("Sets default user options if none are provided") - def test_default_options( - self, - mocker, - option_test_required_patching, - 
client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - mock_config = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - - client_create_method(*create_method_args) - - # Pipeline Config was instantiated with default arguments - assert mock_config.call_count == 1 - expected_kwargs = {} - assert mock_config.call_args == mocker.call(**expected_kwargs) - - # This default config was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][1] == mock_config.return_value - assert mock_http_pipeline_init.call_args[0][1] == mock_config.return_value - - # Get auth provider object, and ensure it was used for both protocol pipelines - auth = mock_mqtt_pipeline_init.call_args[0][0] - assert auth == mock_http_pipeline_init.call_args[0][0] - - # Ensure that auth options are set to expected defaults - assert auth.server_verification_cert is None - - -class SharedClientCreateFromConnectionStringTests(object): - @pytest.mark.it("Uses the connection string to create a SymmetricKeyAuthenticationProvider") - def test_auth_provider_creation(self, mocker, client_class, connection_string): - mock_auth_parse = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse - - client_class.create_from_connection_string(connection_string) - - assert mock_auth_parse.call_count == 1 - assert mock_auth_parse.call_args == mocker.call(connection_string) - - @pytest.mark.it("Uses the SymmetricKeyAuthenticationProvider to create an MQTTPipeline") - def test_pipeline_creation( - self, mocker, client_class, connection_string, mock_mqtt_pipeline_init - ): - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse.return_value - - mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - - client_class.create_from_connection_string(connection_string) - - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args == mocker.call( - mock_auth, mock_config_init.return_value - ) - - @pytest.mark.it("Uses the MQTTPipeline to instantiate the client") - def test_client_instantiation(self, mocker, client_class, connection_string): - mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline").return_value - mock_pipeline_http = mocker.patch( - "azure.iot.device.iothub.pipeline.HTTPPipeline" - ).return_value - spy_init = mocker.spy(client_class, "__init__") - client_class.create_from_connection_string(connection_string) - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) - - @pytest.mark.it("Returns the instantiated client") - def test_returns_client(self, client_class, connection_string): - client = client_class.create_from_connection_string(connection_string) - - assert isinstance(client, client_class) - - # TODO: If auth package was refactored to use ConnectionString class, tests from that - # class would increase the coverage here. 
- @pytest.mark.it("Raises ValueError when given an invalid connection string") - @pytest.mark.parametrize( - "bad_cs", - [ - pytest.param("not-a-connection-string", id="Garbage string"), - pytest.param(object(), id="Non-string input"), - pytest.param( - "HostName=Invalid;DeviceId=Invalid;SharedAccessKey=Invalid", - id="Malformed Connection String", - marks=pytest.mark.xfail(reason="Bug in pipeline + need for auth refactor"), # TODO - ), - ], - ) - def test_raises_value_error_on_bad_connection_string(self, client_class, bad_cs): - with pytest.raises(ValueError): - client_class.create_from_connection_string(bad_cs) +################## +# INFRASTRUCTURE # +################## class WaitsForEventCompletion(object): @@ -309,6 +61,11 @@ def check_callback_completes_event(): event_mock.wait.side_effect = check_callback_completes_event +####################### +# SHARED CLIENT TESTS # +####################### + + class SharedClientConnectTests(WaitsForEventCompletion): @pytest.mark.it("Begins a 'connect' pipeline operation") def test_calls_pipeline_connect(self, client, mqtt_pipeline): @@ -1075,20 +832,6 @@ def test_no_message_in_inbox_nonblocking_mode(self, client): assert result is None -class SharedClientPROPERTYConnectedTests(object): - @pytest.mark.it("Cannot be changed") - def test_read_only(self, client): - with pytest.raises(AttributeError): - client.connected = not client.connected - - @pytest.mark.it("Reflects the value of the root stage property of the same name") - def test_reflects_pipeline_property(self, client, mqtt_pipeline): - mqtt_pipeline.connected = True - assert client.connected - mqtt_pipeline.connected = False - assert not client.connected - - ################ # DEVICE TESTS # ################ @@ -1125,7 +868,7 @@ def sas_token_string(self, device_sas_token_string): @pytest.mark.describe("IoTHubDeviceClient (Synchronous) - Instantiation") class TestIoTHubDeviceClientInstantiation( - IoTHubDeviceClientTestsConfig, SharedClientInstantiationTests + IoTHubDeviceClientTestsConfig, SharedIoTHubClientInstantiationTests ): @pytest.mark.it("Sets on_c2d_message_received handler in the MQTTPipeline") def test_sets_on_c2d_message_received_handler_in_pipeline( @@ -1141,173 +884,23 @@ def test_sets_on_c2d_message_received_handler_in_pipeline( @pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_connection_string()") class TestIoTHubDeviceClientCreateFromConnectionString( - IoTHubDeviceClientTestsConfig, - SharedClientCreateMethodUserOptionTests, - SharedClientCreateFromConnectionStringTests, + IoTHubDeviceClientTestsConfig, SharedIoTHubClientCreateFromConnectionStringTests ): - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_connection_string - - @pytest.fixture - def create_method_args(self, connection_string): - """Provides the specific create method args for use in universal tests""" - return [connection_string] + pass @pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_symmetric_key()") class TestIoTHubDeviceClientCreateFromSymmetricKey( - IoTHubDeviceClientTestsConfig, SharedClientCreateMethodUserOptionTests + IoTHubDeviceClientTestsConfig, SharedIoTHubDeviceClientCreateFromSymmetricKeyTests ): - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_symmetric_key - - @pytest.fixture - def create_method_args(self, 
symmetric_key, hostname_fixture, device_id_fixture): - """Provides the specific create method args for use in universal tests""" - return [symmetric_key, hostname_fixture, device_id_fixture] - - @pytest.mark.it("Uses the symmetric key to create a SymmetricKeyAuthenticationProvider") - def test_auth_provider_creation( - self, mocker, client_class, symmetric_key, hostname_fixture, device_id_fixture - ): - mock_auth_init = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ) - - client_class.create_from_symmetric_key( - symmetric_key=symmetric_key, hostname=hostname_fixture, device_id=device_id_fixture - ) - - assert mock_auth_init.call_count == 1 - assert mock_auth_init.call_args == mocker.call( - hostname=hostname_fixture, - device_id=device_id_fixture, - module_id=None, - shared_access_key=symmetric_key, - ) - - @pytest.mark.it("Uses the SymmetricKeyAuthenticationProvider to create an MQTTPipeline") - def test_pipeline_creation( - self, - mocker, - client_class, - symmetric_key, - hostname_fixture, - device_id_fixture, - mock_mqtt_pipeline_init, - ): - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).return_value - - mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - - client_class.create_from_symmetric_key( - symmetric_key=symmetric_key, hostname=hostname_fixture, device_id=device_id_fixture - ) - - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args == mocker.call( - mock_auth, mock_config_init.return_value - ) - - @pytest.mark.it("Uses the MQTTPipeline to instantiate the client") - def test_client_instantiation( - self, mocker, client_class, symmetric_key, hostname_fixture, device_id_fixture - ): - mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline").return_value - mock_pipeline_http = mocker.patch( - "azure.iot.device.iothub.pipeline.HTTPPipeline" - ).return_value - spy_init = mocker.spy(client_class, "__init__") - client_class.create_from_symmetric_key( - symmetric_key=symmetric_key, hostname=hostname_fixture, device_id=device_id_fixture - ) - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) - - @pytest.mark.it("Returns the instantiated client") - def test_returns_client(self, client_class, symmetric_key, hostname_fixture, device_id_fixture): - client = client_class.create_from_symmetric_key( - symmetric_key=symmetric_key, hostname=hostname_fixture, device_id=device_id_fixture - ) - - assert isinstance(client, client_class) + pass @pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .create_from_x509_certificate()") class TestIoTHubDeviceClientCreateFromX509Certificate( - IoTHubDeviceClientTestsConfig, SharedClientCreateMethodUserOptionTests + IoTHubDeviceClientTestsConfig, SharedIoTHubDeviceClientCreateFromX509CertificateTests ): - hostname = "durmstranginstitute.farend" - device_id = "MySnitch" - - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_x509_certificate - - @pytest.fixture - def create_method_args(self, x509): - """Provides the specific create method args for use in universal tests""" - return [x509, self.hostname, self.device_id] - - @pytest.mark.it("Uses the provided arguments to create a X509AuthenticationProvider") - def test_auth_provider_creation(self, mocker, client_class, x509): - 
mock_auth_init = mocker.patch("azure.iot.device.iothub.auth.X509AuthenticationProvider") - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - - assert mock_auth_init.call_count == 1 - assert mock_auth_init.call_args == mocker.call( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - - @pytest.mark.it("Uses the X509AuthenticationProvider to create an MQTTPipeline") - def test_pipeline_creation(self, mocker, client_class, x509, mock_mqtt_pipeline_init): - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.X509AuthenticationProvider" - ).return_value - - mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args == mocker.call( - mock_auth, mock_config_init.return_value - ) - - @pytest.mark.it("Uses the MQTTPipeline to instantiate the client") - def test_client_instantiation(self, mocker, client_class, x509): - mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline").return_value - mock_pipeline_http = mocker.patch( - "azure.iot.device.iothub.pipeline.HTTPPipeline" - ).return_value - spy_init = mocker.spy(client_class, "__init__") - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) - - @pytest.mark.it("Returns the instantiated client") - def test_returns_client(self, mocker, client_class, x509): - client = client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id - ) - - assert isinstance(client, client_class) + pass @pytest.mark.describe("IoTHubDeviceClient (Synchronous) - .connect()") @@ -1601,7 +1194,7 @@ def test_raises_error_on_pipeline_op_error( @pytest.mark.describe("IoTHubDeviceClient (Synchronous) - PROPERTY .connected") class TestIoTHubDeviceClientPROPERTYConnected( - IoTHubDeviceClientTestsConfig, SharedClientPROPERTYConnectedTests + IoTHubDeviceClientTestsConfig, SharedIoTHubClientPROPERTYConnectedTests ): pass @@ -1642,7 +1235,7 @@ def sas_token_string(self, module_sas_token_string): @pytest.mark.describe("IoTHubModuleClient (Synchronous) - Instantiation") class TestIoTHubModuleClientInstantiation( - IoTHubModuleClientTestsConfig, SharedClientInstantiationTests + IoTHubModuleClientTestsConfig, SharedIoTHubClientInstantiationTests ): @pytest.mark.it("Sets on_input_message_received handler in the MQTTPipeline") def test_sets_on_input_message_received_handler_in_pipeline( @@ -1659,521 +1252,36 @@ def test_sets_on_input_message_received_handler_in_pipeline( @pytest.mark.describe("IoTHubModuleClient (Synchronous) - .create_from_connection_string()") class TestIoTHubModuleClientCreateFromConnectionString( - IoTHubModuleClientTestsConfig, - SharedClientCreateMethodUserOptionTests, - SharedClientCreateFromConnectionStringTests, -): - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_connection_string - - @pytest.fixture - def create_method_args(self, connection_string): - """Provides the specific create method args for use in universal tests""" - return [connection_string] - - -class 
IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests( - SharedClientCreateMethodUserOptionTests + IoTHubModuleClientTestsConfig, SharedIoTHubClientCreateFromConnectionStringTests ): - """This class inherites the user option tests shared by all create method APIs, and overrides - tests in order to accomodate unique requirements for the .create_from_edge_enviornment() method. - - Because .create_from_edge_environment() tests are spread accross multiple test units - (i.e. test classes), these overrides are done in this class, which is then inherited by all - .create_from_edge_environment() test units below. - """ - - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_edge_environment - - @pytest.fixture - def create_method_args(self): - """Provides the specific create method args for use in universal tests""" - return [] - - @pytest.mark.it( - "Raises a TypeError if the 'server_verification_cert' user option parameter is provided" - ) - def test_server_verification_cert_option( - self, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - """THIS TEST OVERRIDES AN INHERITED TEST""" - - with pytest.raises(TypeError): - client_create_method( - *create_method_args, server_verification_cert="fake_server_verification_cert" - ) - - @pytest.mark.it("Sets default user options if none are provided") - def test_default_options( - self, - mocker, - option_test_required_patching, - client_create_method, - create_method_args, - mock_mqtt_pipeline_init, - mock_http_pipeline_init, - ): - """THIS TEST OVERRIDES AN INHERITED TEST""" - mock_config = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - - client_create_method(*create_method_args) - - # Pipeline Config was instantiated with default arguments - assert mock_config.call_count == 1 - expected_kwargs = {} - assert mock_config.call_args == mocker.call(**expected_kwargs) - - # This default config was used for both protocol pipelines - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args[0][1] == mock_config.return_value - assert mock_http_pipeline_init.call_args[0][1] == mock_config.return_value - - # Get auth provider object, and ensure it was used for both protocol pipelines - auth = mock_mqtt_pipeline_init.call_args[0][0] - assert auth == mock_http_pipeline_init.call_args[0][0] + pass @pytest.mark.describe( "IoTHubModuleClient (Synchronous) - .create_from_edge_environment() -- Edge Container Environment" ) class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnv( - IoTHubModuleClientTestsConfig, IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests + IoTHubModuleClientTestsConfig, + SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithContainerEnvTests, ): - @pytest.fixture - def option_test_required_patching(self, mocker, edge_container_environment): - """THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" - mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - - @pytest.mark.it( - "Uses Edge container environment variables to create an IoTEdgeAuthenticationProvider" - ) - def test_auth_provider_creation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - mock_auth_init = 
mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - - client_class.create_from_edge_environment() - - assert mock_auth_init.call_count == 1 - assert mock_auth_init.call_args == mocker.call( - hostname=edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"], - device_id=edge_container_environment["IOTEDGE_DEVICEID"], - module_id=edge_container_environment["IOTEDGE_MODULEID"], - gateway_hostname=edge_container_environment["IOTEDGE_GATEWAYHOSTNAME"], - module_generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], - workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], - api_version=edge_container_environment["IOTEDGE_APIVERSION"], - ) - - @pytest.mark.it( - "Ignores any Edge local debug environment variables that may be present, in favor of using Edge container variables" - ) - def test_auth_provider_creation_hybrid_env( - self, mocker, client_class, edge_container_environment, edge_local_debug_environment - ): - # This test verifies that with a hybrid environment, the auth provider will always be - # an IoTEdgeAuthenticationProvider, even if local debug variables are present - hybrid_environment = merge_dicts(edge_container_environment, edge_local_debug_environment) - mocker.patch.dict(os.environ, hybrid_environment, clear=True) - mock_edge_auth_init = mocker.patch( - "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" - ) - mock_sk_auth_parse = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse - - client_class.create_from_edge_environment() - - assert mock_edge_auth_init.call_count == 1 - assert mock_sk_auth_parse.call_count == 0 # we did NOT use SK auth - assert mock_edge_auth_init.call_args == mocker.call( - hostname=edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"], - device_id=edge_container_environment["IOTEDGE_DEVICEID"], - module_id=edge_container_environment["IOTEDGE_MODULEID"], - gateway_hostname=edge_container_environment["IOTEDGE_GATEWAYHOSTNAME"], - module_generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], - workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], - api_version=edge_container_environment["IOTEDGE_APIVERSION"], - ) - - @pytest.mark.it( - "Uses the IoTEdgeAuthenticationProvider to create an MQTTPipeline and an HTTPPipeline" - ) - def test_pipeline_creation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" - ).return_value - mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - mock_mqtt_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline") - mock_http_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") - - client_class.create_from_edge_environment() - - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args == mocker.call( - mock_auth, mock_config_init.return_value - ) - assert mock_http_pipeline_init.call_count == 1 - # This asserts without mock_config_init because currently edge isn't implemented. When it is, this should be identical to the line aboe. 
- assert mock_http_pipeline_init.call_args == mocker.call( - mock_auth, mock_config_init.return_value - ) - - @pytest.mark.it("Uses the MQTTPipeline and the HTTPPipeline to instantiate the client") - def test_client_instantiation(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - # Always patch the IoTEdgeAuthenticationProvider to prevent I/O operations - mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - mock_mqtt_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.MQTTPipeline" - ).return_value - mock_http_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.HTTPPipeline" - ).return_value - spy_init = mocker.spy(client_class, "__init__") - - client_class.create_from_edge_environment() - - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_mqtt_pipeline, mock_http_pipeline) - - @pytest.mark.it("Returns the instantiated client") - def test_returns_client(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - # Always patch the IoTEdgeAuthenticationProvider to prevent I/O operations - mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - - client = client_class.create_from_edge_environment() - - assert isinstance(client, client_class) - - @pytest.mark.it("Raises OSError if the environment is missing required variables") - @pytest.mark.parametrize( - "missing_env_var", - [ - "IOTEDGE_MODULEID", - "IOTEDGE_DEVICEID", - "IOTEDGE_IOTHUBHOSTNAME", - "IOTEDGE_GATEWAYHOSTNAME", - "IOTEDGE_APIVERSION", - "IOTEDGE_MODULEGENERATIONID", - "IOTEDGE_WORKLOADURI", - ], - ) - def test_bad_environment( - self, mocker, client_class, edge_container_environment, missing_env_var - ): - # Remove a variable from the fixture - del edge_container_environment[missing_env_var] - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - - with pytest.raises(OSError): - client_class.create_from_edge_environment() - - @pytest.mark.it("Raises OSError if there is an error using the Edge for authentication") - def test_bad_edge_auth(self, mocker, client_class, edge_container_environment): - mocker.patch.dict(os.environ, edge_container_environment, clear=True) - mock_auth = mocker.patch("azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider") - my_edge_error = IoTEdgeError() - mock_auth.side_effect = my_edge_error - - with pytest.raises(OSError) as e_info: - client_class.create_from_edge_environment() - assert e_info.value.__cause__ is my_edge_error + pass @pytest.mark.describe( "IoTHubModuleClient (Synchronous) - .create_from_edge_environment() -- Edge Local Debug Environment" ) class TestIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnv( - IoTHubModuleClientTestsConfig, IoTHubModuleClientClientCreateFromEdgeEnvironmentUserOptionTests + IoTHubModuleClientTestsConfig, + SharedIoTHubModuleClientCreateFromEdgeEnvironmentWithDebugEnvTests, ): - @pytest.fixture - def option_test_required_patching(self, mocker, edge_local_debug_environment): - """THIS FIXTURE OVERRIDES AN INHERITED FIXTURE""" - mocker.patch("azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider") - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - mocker.patch.object(io, "open") - - @pytest.fixture - def mock_open(self, mocker): - return mocker.patch.object(io, "open") - - @pytest.mark.it( - "Extracts the server verification certificate from the 
file indicated by the EdgeModuleCACertificateFile environment variable" - ) - def test_read_server_verification_cert( - self, mocker, client_class, edge_local_debug_environment, mock_open - ): - mock_file_handle = mock_open.return_value.__enter__.return_value - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - client_class.create_from_edge_environment() - assert mock_open.call_count == 1 - assert mock_open.call_args == mocker.call( - edge_local_debug_environment["EdgeModuleCACertificateFile"], mode="r" - ) - assert mock_file_handle.read.call_count == 1 - - @pytest.mark.it( - "Uses Edge local debug environment variables to create a SymmetricKeyAuthenticationProvider (with server verification cert)" - ) - def test_auth_provider_creation( - self, mocker, client_class, edge_local_debug_environment, mock_open - ): - expected_cert = mock_open.return_value.__enter__.return_value.read.return_value - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - mock_auth_parse = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse - - client_class.create_from_edge_environment() - - assert mock_auth_parse.call_count == 1 - assert mock_auth_parse.call_args == mocker.call( - edge_local_debug_environment["EdgeHubConnectionString"] - ) - assert mock_auth_parse.return_value.server_verification_cert == expected_cert - - @pytest.mark.it( - "Only uses Edge local debug variables if no Edge container variables are present in the environment" - ) - def test_auth_provider_and_pipeline_hybrid_env( - self, - mocker, - client_class, - edge_container_environment, - edge_local_debug_environment, - mock_open, - ): - # This test verifies that with a hybrid environment, the auth provider will always be - # an IoTEdgeAuthenticationProvider, even if local debug variables are present - hybrid_environment = merge_dicts(edge_container_environment, edge_local_debug_environment) - mocker.patch.dict(os.environ, hybrid_environment, clear=True) - mock_edge_auth_init = mocker.patch( - "azure.iot.device.iothub.auth.IoTEdgeAuthenticationProvider" - ) - mock_sk_auth_parse = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse - - client_class.create_from_edge_environment() - - assert mock_edge_auth_init.call_count == 1 - assert mock_sk_auth_parse.call_count == 0 # we did NOT use SK auth - assert mock_edge_auth_init.call_args == mocker.call( - hostname=edge_container_environment["IOTEDGE_IOTHUBHOSTNAME"], - device_id=edge_container_environment["IOTEDGE_DEVICEID"], - module_id=edge_container_environment["IOTEDGE_MODULEID"], - gateway_hostname=edge_container_environment["IOTEDGE_GATEWAYHOSTNAME"], - module_generation_id=edge_container_environment["IOTEDGE_MODULEGENERATIONID"], - workload_uri=edge_container_environment["IOTEDGE_WORKLOADURI"], - api_version=edge_container_environment["IOTEDGE_APIVERSION"], - ) - - @pytest.mark.it( - "Uses the SymmetricKeyAuthenticationProvider to create an MQTTPipeline and an HTTPPipeline" - ) - def test_pipeline_creation(self, mocker, client_class, edge_local_debug_environment, mock_open): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.SymmetricKeyAuthenticationProvider" - ).parse.return_value - mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - mock_mqtt_pipeline_init = mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline") - mock_http_pipeline_init = 
mocker.patch("azure.iot.device.iothub.pipeline.HTTPPipeline") - - client_class.create_from_edge_environment() - - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args == mocker.call( - mock_auth, mock_config_init.return_value - ) - assert mock_http_pipeline_init.call_count == 1 - assert mock_http_pipeline_init.call_args == mocker.call( - mock_auth, mock_config_init.return_value - ) - - @pytest.mark.it("Uses the MQTTPipeline and the HTTPPipeline to instantiate the client") - def test_client_instantiation( - self, mocker, client_class, edge_local_debug_environment, mock_open - ): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - mock_mqtt_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.MQTTPipeline" - ).return_value - mock_http_pipeline = mocker.patch( - "azure.iot.device.iothub.pipeline.HTTPPipeline" - ).return_value - spy_init = mocker.spy(client_class, "__init__") - - client_class.create_from_edge_environment() - - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_mqtt_pipeline, mock_http_pipeline) - - @pytest.mark.it("Returns the instantiated client") - def test_returns_client(self, mocker, client_class, edge_local_debug_environment, mock_open): - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - - client = client_class.create_from_edge_environment() - - assert isinstance(client, client_class) - - @pytest.mark.it("Raises OSError if the environment is missing required variables") - @pytest.mark.parametrize( - "missing_env_var", ["EdgeHubConnectionString", "EdgeModuleCACertificateFile"] - ) - def test_bad_environment( - self, mocker, client_class, edge_local_debug_environment, missing_env_var, mock_open - ): - # Remove a variable from the fixture - del edge_local_debug_environment[missing_env_var] - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - - with pytest.raises(OSError): - client_class.create_from_edge_environment() - - # TODO: If auth package was refactored to use ConnectionString class, tests from that - # class would increase the coverage here. 
- @pytest.mark.it( - "Raises ValueError if the connection string in the EdgeHubConnectionString environment variable is invalid" - ) - @pytest.mark.parametrize( - "bad_cs", - [ - pytest.param("not-a-connection-string", id="Garbage string"), - pytest.param("", id="Empty string"), - pytest.param( - "HostName=Invalid;DeviceId=Invalid;ModuleId=Invalid;SharedAccessKey=Invalid;GatewayHostName=Invalid", - id="Malformed Connection String", - marks=pytest.mark.xfail(reason="Bug in pipeline + need for auth refactor"), # TODO - ), - ], - ) - def test_bad_connection_string( - self, mocker, client_class, edge_local_debug_environment, bad_cs, mock_open - ): - edge_local_debug_environment["EdgeHubConnectionString"] = bad_cs - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - - with pytest.raises(ValueError): - client_class.create_from_edge_environment() - - @pytest.mark.it( - "Raises ValueError if the filepath in the EdgeModuleCACertificateFile environment variable is invalid" - ) - def test_bad_filepath(self, mocker, client_class, edge_local_debug_environment, mock_open): - # To make tests compatible with Python 2 & 3, redfine errors - try: - FileNotFoundError # noqa: F823 - except NameError: - FileNotFoundError = IOError - - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - my_fnf_error = FileNotFoundError() - mock_open.side_effect = my_fnf_error - with pytest.raises(ValueError) as e_info: - client_class.create_from_edge_environment() - assert e_info.value.__cause__ is my_fnf_error - - @pytest.mark.it( - "Raises ValueError if the file referenced by the filepath in the EdgeModuleCACertificateFile environment variable cannot be opened" - ) - def test_bad_file_io(self, mocker, client_class, edge_local_debug_environment, mock_open): - # Raise a different error in Python 2 vs 3 - if six.PY2: - error = IOError() - else: - error = OSError() - mocker.patch.dict(os.environ, edge_local_debug_environment, clear=True) - mock_open.side_effect = error - with pytest.raises(ValueError) as e_info: - client_class.create_from_edge_environment() - assert e_info.value.__cause__ is error + pass @pytest.mark.describe("IoTHubModuleClient (Synchronous) - .create_from_x509_certificate()") class TestIoTHubModuleClientCreateFromX509Certificate( - IoTHubModuleClientTestsConfig, SharedClientCreateMethodUserOptionTests + IoTHubModuleClientTestsConfig, SharedIoTHubModuleClientCreateFromX509CertificateTests ): - hostname = "durmstranginstitute.farend" - device_id = "MySnitch" - module_id = "Charms" - - @pytest.fixture - def client_create_method(self, client_class): - """Provides the specific create method for use in universal tests""" - return client_class.create_from_x509_certificate - - @pytest.fixture - def create_method_args(self, x509): - """Provides the specific create method args for use in universal tests""" - return [x509, self.hostname, self.device_id, self.module_id] - - @pytest.mark.it("Uses the provided arguments to create a X509AuthenticationProvider") - def test_auth_provider_creation(self, mocker, client_class, x509): - mock_auth_init = mocker.patch("azure.iot.device.iothub.auth.X509AuthenticationProvider") - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - - assert mock_auth_init.call_count == 1 - assert mock_auth_init.call_args == mocker.call( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - - @pytest.mark.it("Uses the 
X509AuthenticationProvider to create an MQTTPipeline") - def test_pipeline_creation(self, mocker, client_class, x509, mock_mqtt_pipeline_init): - mock_auth = mocker.patch( - "azure.iot.device.iothub.auth.X509AuthenticationProvider" - ).return_value - - mock_config_init = mocker.patch("azure.iot.device.iothub.pipeline.IoTHubPipelineConfig") - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - - assert mock_mqtt_pipeline_init.call_count == 1 - assert mock_mqtt_pipeline_init.call_args == mocker.call( - mock_auth, mock_config_init.return_value - ) - - @pytest.mark.it("Uses the MQTTPipeline to instantiate the client") - def test_client_instantiation(self, mocker, client_class, x509): - mock_pipeline = mocker.patch("azure.iot.device.iothub.pipeline.MQTTPipeline").return_value - mock_pipeline_http = mocker.patch( - "azure.iot.device.iothub.pipeline.HTTPPipeline" - ).return_value - - spy_init = mocker.spy(client_class, "__init__") - - client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - - assert spy_init.call_count == 1 - assert spy_init.call_args == mocker.call(mocker.ANY, mock_pipeline, mock_pipeline_http) - - @pytest.mark.it("Returns the instantiated client") - def test_returns_client(self, mocker, client_class, x509): - client = client_class.create_from_x509_certificate( - x509=x509, hostname=self.hostname, device_id=self.device_id, module_id=self.module_id - ) - - assert isinstance(client, client_class) + pass @pytest.mark.describe("IoTHubModuleClient (Synchronous) - .connect()") @@ -2202,22 +1310,22 @@ class TestIoTHubNModuleClientSendD2CMessage( @pytest.mark.describe("IoTHubModuleClient (Synchronous) - .send_message_to_output()") class TestIoTHubModuleClientSendToOutput(IoTHubModuleClientTestsConfig, WaitsForEventCompletion): - @pytest.mark.it("Begins a 'send_output_event' pipeline operation") + @pytest.mark.it("Begins a 'send_output_message' pipeline operation") def test_calls_pipeline_send_message_to_output(self, client, mqtt_pipeline, message): output_name = "some_output" client.send_message_to_output(message, output_name) - assert mqtt_pipeline.send_output_event.call_count == 1 - assert mqtt_pipeline.send_output_event.call_args[0][0] is message + assert mqtt_pipeline.send_output_message.call_count == 1 + assert mqtt_pipeline.send_output_message.call_args[0][0] is message assert message.output_name == output_name @pytest.mark.it( - "Waits for the completion of the 'send_output_event' pipeline operation before returning" + "Waits for the completion of the 'send_output_message' pipeline operation before returning" ) def test_waits_for_pipeline_op_completion( self, mocker, client_manual_cb, mqtt_pipeline_manual_cb, message ): self.add_event_completion_checks( - mocker=mocker, pipeline_function=mqtt_pipeline_manual_cb.send_output_event + mocker=mocker, pipeline_function=mqtt_pipeline_manual_cb.send_output_message ) output_name = "some_output" client_manual_cb.send_message_to_output(message, output_name) @@ -2263,7 +1371,7 @@ def test_raises_error_on_pipeline_op_error( my_pipeline_error = pipeline_error() self.add_event_completion_checks( mocker=mocker, - pipeline_function=mqtt_pipeline_manual_cb.send_output_event, + pipeline_function=mqtt_pipeline_manual_cb.send_output_message, kwargs={"error": my_pipeline_error}, ) output_name = "some_output" @@ -2290,8 +1398,8 @@ def 
test_send_message_to_output_calls_pipeline_wraps_data_in_message( ): output_name = "some_output" client.send_message_to_output(message_input, output_name) - assert mqtt_pipeline.send_output_event.call_count == 1 - sent_message = mqtt_pipeline.send_output_event.call_args[0][0] + assert mqtt_pipeline.send_output_message.call_count == 1 + sent_message = mqtt_pipeline.send_output_message.call_args[0][0] assert isinstance(sent_message, Message) assert sent_message.data == message_input @@ -2303,7 +1411,7 @@ def test_raises_error_when_message_to_output_data_greater_than_256(self, client, with pytest.raises(ValueError) as e_info: client.send_message_to_output(message, output_name) assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_output_event.call_count == 0 + assert mqtt_pipeline.send_output_message.call_count == 0 @pytest.mark.it("Raises error when message size is greater than 256 KB") def test_raises_error_when_message_to_output_size_greater_than_256(self, client, mqtt_pipeline): @@ -2314,7 +1422,7 @@ def test_raises_error_when_message_to_output_size_greater_than_256(self, client, with pytest.raises(ValueError) as e_info: client.send_message_to_output(message, output_name) assert "256 KB" in e_info.value.args[0] - assert mqtt_pipeline.send_output_event.call_count == 0 + assert mqtt_pipeline.send_output_message.call_count == 0 @pytest.mark.it("Does not raises error when message data size is equal to 256 KB") def test_raises_error_when_message_to_output_data_equal_to_256(self, client, mqtt_pipeline): @@ -2329,8 +1437,8 @@ def test_raises_error_when_message_to_output_data_equal_to_256(self, client, mqt client.send_message_to_output(message, output_name) - assert mqtt_pipeline.send_output_event.call_count == 1 - sent_message = mqtt_pipeline.send_output_event.call_args[0][0] + assert mqtt_pipeline.send_output_message.call_count == 1 + sent_message = mqtt_pipeline.send_output_message.call_args[0][0] assert isinstance(sent_message, Message) assert sent_message.data == data_input @@ -2551,15 +1659,6 @@ def test_raises_error_on_pipeline_op_error( @pytest.mark.describe("IoTHubModule (Synchronous) - PROPERTY .connected") class TestIoTHubModuleClientPROPERTYConnected( - IoTHubModuleClientTestsConfig, SharedClientPROPERTYConnectedTests + IoTHubModuleClientTestsConfig, SharedIoTHubClientPROPERTYConnectedTests ): pass - - -#################### -# HELPER FUNCTIONS # -#################### -def merge_dicts(d1, d2): - d3 = d1.copy() - d3.update(d2) - return d3 diff --git a/azure-iot-device/tests/provisioning/aio/__init__.py b/azure-iot-device/tests/provisioning/aio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/azure-iot-device/tests/provisioning/aio/test_async_provisioning_device_client.py b/azure-iot-device/tests/provisioning/aio/test_async_provisioning_device_client.py index adaa315d6..9a654e87b 100644 --- a/azure-iot-device/tests/provisioning/aio/test_async_provisioning_device_client.py +++ b/azure-iot-device/tests/provisioning/aio/test_async_provisioning_device_client.py @@ -8,34 +8,21 @@ from azure.iot.device.provisioning.aio.async_provisioning_device_client import ( ProvisioningDeviceClient, ) -from azure.iot.device.provisioning.models.registration_result import ( - RegistrationResult, - RegistrationState, -) -from azure.iot.device.provisioning import security, pipeline -from azure.iot.device.common.models.x509 import X509 +from azure.iot.device.provisioning import pipeline from azure.iot.device.common import async_adapter import asyncio from 
azure.iot.device.iothub.pipeline import exceptions as pipeline_exceptions from azure.iot.device import exceptions as client_exceptions +from ..shared_client_tests import ( + SharedProvisioningClientInstantiationTests, + SharedProvisioningClientCreateFromSymmetricKeyTests, + SharedProvisioningClientCreateFromX509CertificateTests, +) + logging.basicConfig(level=logging.DEBUG) pytestmark = pytest.mark.asyncio -fake_symmetric_key = "Zm9vYmFy" -fake_registration_id = "MyPensieve" -fake_id_scope = "Enchanted0000Ceiling7898" -fake_provisioning_host = "hogwarts.com" -fake_x509_cert_file_value = "fantastic_beasts" -fake_x509_cert_key_file = "where_to_find_them" -fake_pass_phrase = "alohomora" -fake_status = "flying" -fake_sub_status = "FlyingOnHippogriff" -fake_operation_id = "quidditch_world_cup" -fake_request_id = "request_1234" -fake_device_id = "MyNimbus2000" -fake_assigned_hub = "Dumbledore'sArmy" - async def create_completed_future(result=None): f = asyncio.Future() @@ -43,273 +30,40 @@ async def create_completed_future(result=None): return f -@pytest.fixture -def registration_result(): - registration_state = RegistrationState(fake_device_id, fake_assigned_hub, fake_sub_status) - return RegistrationResult(fake_operation_id, fake_status, registration_state) - - -@pytest.fixture -def x509(): - return X509(fake_x509_cert_file_value, fake_x509_cert_key_file, fake_pass_phrase) - - -@pytest.fixture(autouse=True) -def provisioning_pipeline(mocker): - return mocker.MagicMock(wraps=FakeProvisioningPipeline()) - - -class FakeProvisioningPipeline: - def __init__(self): - self.responses_enabled = {} - - def connect(self, callback): - callback() - - def disconnect(self, callback): - callback() - - def enable_responses(self, callback): - callback() - - def register(self, payload, callback): - callback(result={}) - - -# automatically mock the pipeline for all tests in this file -@pytest.fixture(autouse=True) -def mock_pipeline_init(mocker): - return mocker.patch("azure.iot.device.provisioning.pipeline.ProvisioningPipeline") - - -class SharedClientCreateMethodUserOptionTests(object): - @pytest.mark.it( - "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" - ) - async def test_websockets_option( - self, mocker, client_create_method, create_method_args, mock_pipeline_init - ): - client_create_method(*create_method_args, websockets=True) - - # Get configuration object - assert mock_pipeline_init.call_count == 1 - config = mock_pipeline_init.call_args[0][1] - - assert config.websockets - - # TODO: Show that input in the wrong format is formatted to the correct one. This test exists - # in the ProvisioningPipelineConfig object already, but we do not currently show that this is felt - # from the API level. 
- @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") - async def test_cipher_option( - self, mocker, client_create_method, create_method_args, mock_pipeline_init - ): - - cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" - client_create_method(*create_method_args, cipher=cipher) - - # Get configuration object - assert mock_pipeline_init.call_count == 1 - config = mock_pipeline_init.call_args[0][1] - - assert config.cipher == cipher - - @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") - async def test_invalid_option( - self, mocker, client_create_method, create_method_args, mock_pipeline_init - ): - with pytest.raises(TypeError): - client_create_method(*create_method_args, invalid_option="some_value") - - @pytest.mark.it("Sets default user options if none are provided") - async def test_default_options( - self, mocker, client_create_method, create_method_args, mock_pipeline_init - ): - mock_config = mocker.patch( - "azure.iot.device.provisioning.pipeline.ProvisioningPipelineConfig" - ) - client_create_method(*create_method_args) - - # Pipeline Config was instantiated with default arguments - assert mock_config.call_count == 1 - expected_kwargs = {} - assert mock_config.call_args == mocker.call(**expected_kwargs) - - # This default config was used for the protocol pipeline - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args[0][1] == mock_config.return_value - - -@pytest.mark.describe("ProvisioningDeviceClient - Instantiation") -class TestClientInstantiation(object): - @pytest.mark.it( - "Stores the ProvisioningPipeline from the 'provisioning_pipeline' parameter in the '_provisioning_pipeline' attribute" - ) - async def test_sets_provisioning_pipeline(self, provisioning_pipeline): - client = ProvisioningDeviceClient(provisioning_pipeline) - - assert client._provisioning_pipeline is provisioning_pipeline - - @pytest.mark.it( - "Instantiates with the initial value of the '_provisioning_payload' attribute set to None" - ) - async def test_payload(self, provisioning_pipeline): - client = ProvisioningDeviceClient(provisioning_pipeline) - - assert client._provisioning_payload is None - +class ProvisioningClientTestsConfig(object): + """Defines fixtures for asynchronous ProvisioningDeviceClient tests""" -@pytest.mark.describe("ProvisioningDeviceClient - .create_from_symmetric_key()") -class TestClientCreateFromSymmetricKey(SharedClientCreateMethodUserOptionTests): @pytest.fixture - async def client_create_method(self): - return ProvisioningDeviceClient.create_from_symmetric_key + def client_class(self): + return ProvisioningDeviceClient @pytest.fixture - async def create_method_args(self): - return [fake_provisioning_host, fake_registration_id, fake_id_scope, fake_symmetric_key] - - @pytest.mark.it("Creates a SymmetricKeySecurityClient using the given parameters") - async def test_security_client(self, mocker): - spy_sec_client = mocker.spy(security, "SymmetricKeySecurityClient") - - ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) + def client(self, provisioning_pipeline): + return ProvisioningDeviceClient(provisioning_pipeline) - assert spy_sec_client.call_count == 1 - assert spy_sec_client.call_args == mocker.call( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - 
id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - @pytest.mark.it( - "Uses the SymmetricKeySecurityClient object and the ProvisioningPipelineConfig object to create a ProvisioningPipeline" - ) - async def test_pipeline(self, mocker, mock_pipeline_init): - # Note that the details of how the pipeline config is set up are covered in the - # SharedClientCreateMethodUserOptionTests - mock_pipeline_config = mocker.patch.object( - pipeline, "ProvisioningPipelineConfig" - ).return_value - mock_sec_client = mocker.patch.object(security, "SymmetricKeySecurityClient").return_value - - ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) +@pytest.mark.describe("ProvisioningDeviceClient (Async) - Instantiation") +class TestProvisioningClientInstantiation( + ProvisioningClientTestsConfig, SharedProvisioningClientInstantiationTests +): + pass - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_sec_client, mock_pipeline_config) - @pytest.mark.it("Uses the ProvisioningPipeline to instantiate the client") - async def test_client_creation(self, mocker, mock_pipeline_init): - spy_client_init = mocker.spy(ProvisioningDeviceClient, "__init__") +@pytest.mark.describe("ProvisioningDeviceClient (Async) - .create_from_symmetric_key()") +class TestProvisioningClientCreateFromSymmetricKey( + ProvisioningClientTestsConfig, SharedProvisioningClientCreateFromSymmetricKeyTests +): + pass - ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - - assert spy_client_init.call_count == 1 - assert spy_client_init.call_args == mocker.call(mocker.ANY, mock_pipeline_init.return_value) - - @pytest.mark.it("Returns the instantiated client") - async def test_returns_client(self, mocker): - client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - assert isinstance(client, ProvisioningDeviceClient) - -@pytest.mark.describe("ProvisioningDeviceClient - .create_from_x509_certificate()") -class TestClientCreateFromX509Certificate(SharedClientCreateMethodUserOptionTests): - @pytest.fixture - def client_create_method(self): - return ProvisioningDeviceClient.create_from_x509_certificate - - @pytest.fixture - def create_method_args(self, x509): - return [fake_provisioning_host, fake_registration_id, fake_id_scope, x509] - - @pytest.mark.it("Creates an X509SecurityClient using the given parameters") - async def test_security_client(self, mocker, x509): - spy_sec_client = mocker.spy(security, "X509SecurityClient") - - ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) - - assert spy_sec_client.call_count == 1 - assert spy_sec_client.call_args == mocker.call( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) - - @pytest.mark.it( - "Uses the X509SecurityClient object and the ProvisioningPipelineConfig object to create a ProvisioningPipeline" - ) - async def test_pipeline(self, mocker, mock_pipeline_init, x509): - # Note that the 
details of how the pipeline config is set up are covered in the - # SharedClientCreateMethodUserOptionTests - mock_pipeline_config = mocker.patch.object( - pipeline, "ProvisioningPipelineConfig" - ).return_value - mock_sec_client = mocker.patch.object(security, "X509SecurityClient").return_value - - ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) - - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_sec_client, mock_pipeline_config) - - @pytest.mark.it("Uses the ProvisioningPipeline to instantiate the client") - async def test_client_creation(self, mocker, mock_pipeline_init, x509): - spy_client_init = mocker.spy(ProvisioningDeviceClient, "__init__") - - ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) - - assert spy_client_init.call_count == 1 - assert spy_client_init.call_args == mocker.call(mocker.ANY, mock_pipeline_init.return_value) - - @pytest.mark.it("Returns the instantiated client") - async def test_returns_client(self, mocker, x509): - client = ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) - assert isinstance(client, ProvisioningDeviceClient) +@pytest.mark.describe("ProvisioningDeviceClient (Async) - .create_from_x509_certificate()") +class TestProvisioningClientCreateFromX509Certificate( + ProvisioningClientTestsConfig, SharedProvisioningClientCreateFromX509CertificateTests +): + pass -@pytest.mark.describe("ProvisioningDeviceClient - .register()") +@pytest.mark.describe("ProvisioningDeviceClient (Async) - .register()") class TestClientRegister(object): @pytest.mark.it("Implicitly enables responses from provisioning service if not already enabled") async def test_enables_provisioning_only_if_not_already_enabled( @@ -437,7 +191,7 @@ def register_complete_failure_callback(payload, callback): assert provisioning_pipeline.register.call_count == 1 -@pytest.mark.describe("ProvisioningDeviceClient - .set_provisioning_payload()") +@pytest.mark.describe("ProvisioningDeviceClient (Async) - .set_provisioning_payload()") class TestClientProvisioningPayload(object): @pytest.mark.it("Sets the payload on the provisioning payload attribute") @pytest.mark.parametrize( diff --git a/azure-iot-device/tests/provisioning/conftest.py b/azure-iot-device/tests/provisioning/conftest.py index 88a5b1924..c0585667b 100644 --- a/azure-iot-device/tests/provisioning/conftest.py +++ b/azure-iot-device/tests/provisioning/conftest.py @@ -6,9 +6,11 @@ import sys import pytest -from azure.iot.device.provisioning.models.registration_result import ( - RegistrationResult, - RegistrationState, +from .shared_client_fixtures import ( + mock_pipeline_init, + provisioning_pipeline, + registration_result, + x509, ) collect_ignore = [] diff --git a/azure-iot-device/tests/provisioning/pipeline/helpers.py b/azure-iot-device/tests/provisioning/pipeline/helpers.py index dccc21973..bc61db593 100644 --- a/azure-iot-device/tests/provisioning/pipeline/helpers.py +++ b/azure-iot-device/tests/provisioning/pipeline/helpers.py @@ -6,9 +6,6 @@ from azure.iot.device.provisioning.pipeline import pipeline_ops_provisioning all_provisioning_ops = [ - 
pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation, - pipeline_ops_provisioning.SetX509SecurityClientOperation, - pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation, pipeline_ops_provisioning.RegisterOperation, pipeline_ops_provisioning.PollStatusOperation, ] diff --git a/azure-iot-device/tests/provisioning/pipeline/test_config.py b/azure-iot-device/tests/provisioning/pipeline/test_config.py index 6ceb64814..40723252b 100644 --- a/azure-iot-device/tests/provisioning/pipeline/test_config.py +++ b/azure-iot-device/tests/provisioning/pipeline/test_config.py @@ -5,9 +5,13 @@ # -------------------------------------------------------------------------- import pytest import logging -from tests.common.pipeline.pipeline_config_test import PipelineConfigInstantiationTestBase +from tests.common.pipeline.config_test import PipelineConfigInstantiationTestBase from azure.iot.device.provisioning.pipeline.config import ProvisioningPipelineConfig +hostname = "hostname.some-domain.net" +registration_id = "registration_id" +id_scope = "id_scope" + @pytest.mark.describe("ProvisioningPipelineConfig - Instantiation") class TestProvisioningPipelineConfigInstantiation(PipelineConfigInstantiationTestBase): @@ -15,3 +19,31 @@ class TestProvisioningPipelineConfigInstantiation(PipelineConfigInstantiationTes def config_cls(self): # This fixture is needed for the parent class return ProvisioningPipelineConfig + + @pytest.fixture + def required_kwargs(self): + # This fixture is needed for the parent class + return {"hostname": hostname, "registration_id": registration_id, "id_scope": id_scope} + + # The parent class defines the auth mechanism fixtures (sastoken, x509). + # For the sake of ease of testing, we will assume sastoken is being used unless + # there is a strict need to do something else. + # It does not matter which is used for the purposes of these tests. 
+ + @pytest.mark.it( + "Instantiates with the 'registration_id' attribute set to the provided 'registration_id' parameter" + ) + def test_registration_id_set(self, sastoken): + config = ProvisioningPipelineConfig( + hostname=hostname, registration_id=registration_id, id_scope=id_scope, sastoken=sastoken + ) + assert config.registration_id == registration_id + + @pytest.mark.it( + "Instantiates with the 'id_scope' attribute set to the provided 'id_scope' parameter" + ) + def test_id_scope_set(self, sastoken): + config = ProvisioningPipelineConfig( + hostname=hostname, registration_id=registration_id, id_scope=id_scope, sastoken=sastoken + ) + assert config.id_scope == id_scope diff --git a/azure-iot-device/tests/provisioning/pipeline/test_provisioning_pipeline.py b/azure-iot-device/tests/provisioning/pipeline/test_mqtt_pipeline.py similarity index 64% rename from azure-iot-device/tests/provisioning/pipeline/test_provisioning_pipeline.py rename to azure-iot-device/tests/provisioning/pipeline/test_mqtt_pipeline.py index 345aed403..7e0ce4645 100644 --- a/azure-iot-device/tests/provisioning/pipeline/test_provisioning_pipeline.py +++ b/azure-iot-device/tests/provisioning/pipeline/test_mqtt_pipeline.py @@ -7,9 +7,7 @@ import pytest import logging from azure.iot.device.common.models import X509 -from azure.iot.device.provisioning.security.sk_security_client import SymmetricKeySecurityClient -from azure.iot.device.provisioning.security.x509_security_client import X509SecurityClient -from azure.iot.device.provisioning.pipeline.provisioning_pipeline import ProvisioningPipeline +from azure.iot.device.provisioning.pipeline.mqtt_pipeline import MQTTPipeline from tests.common.pipeline import helpers import json from azure.iot.device.provisioning.pipeline import constant as dps_constants @@ -31,91 +29,23 @@ feature = dps_constants.REGISTER -fake_symmetric_key = "Zm9vYmFy" -fake_registration_id = "MyPensieve" -fake_id_scope = "Enchanted0000Ceiling7898" -fake_provisioning_host = "beauxbatons.academy-net" -fake_device_id = "elder_wand" -fake_registration_id = "registered_remembrall" -fake_provisioning_host = "hogwarts.com" -fake_id_scope = "weasley_wizard_wheezes" -fake_ca_cert = "fake_certificate" -fake_sas_token = "horcrux_token" -fake_security_client = "secure_via_muffliato" -fake_request_id = "fake_request_1234" -fake_mqtt_payload = "hello hogwarts" -fake_register_publish_payload = '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( - reg_id=fake_registration_id, json_payload=json.dumps(fake_mqtt_payload) -) -fake_operation_id = "fake_operation_9876" -fake_sub_unsub_topic = "$dps/registrations/res/#" -fake_x509_cert_file = "fantastic_beasts" -fake_x509_cert_key_file = "where_to_find_them" -fake_pass_phrase = "alohomora" -fake_registration_result = "fake_result" -fake_request_payload = "fake_request_payload" - - def mock_x509(): - return X509(fake_x509_cert_file, fake_x509_cert_key_file, fake_pass_phrase) - - -different_security_clients = [ - ( - SymmetricKeySecurityClient, - { - "provisioning_host": fake_provisioning_host, - "registration_id": fake_registration_id, - "id_scope": fake_id_scope, - "symmetric_key": fake_symmetric_key, - }, - ), - ( - X509SecurityClient, - { - "provisioning_host": fake_provisioning_host, - "registration_id": fake_registration_id, - "id_scope": fake_id_scope, - "x509": mock_x509(), - }, - ), -] - - -def assert_for_symmetric_key(password): - assert password is not None - assert "SharedAccessSignature" in password - assert "skn=registration" in password -
assert fake_id_scope in password - assert fake_registration_id in password - - -def assert_for_client_x509(x509): - assert x509 is not None - assert x509.certificate_file == fake_x509_cert_file - assert x509.key_file == fake_x509_cert_key_file - assert x509.pass_phrase == fake_pass_phrase - - -@pytest.fixture( - scope="function", - params=different_security_clients, - ids=[x[0].__name__ for x in different_security_clients], -) -def input_security_client(request): - sec_client_class = request.param[0] - init_kwargs = request.param[1] - return sec_client_class(**init_kwargs) + return X509( + cert_file="fantastic_beasts", key_file="where_to_find_them", pass_phrase="alohomora" + ) @pytest.fixture def pipeline_configuration(mocker): - return mocker.MagicMock() + mock_config = mocker.MagicMock() + mock_config.sastoken.ttl = 1232 # set for compat + mock_config.registration_id = "MyPensieve" + return mock_config @pytest.fixture -def pipeline(mocker, input_security_client, pipeline_configuration): - pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) +def pipeline(mocker, pipeline_configuration): + pipeline = MQTTPipeline(pipeline_configuration) mocker.patch.object(pipeline._pipeline, "run_op") return pipeline @@ -124,45 +54,45 @@ def pipeline(mocker, input_security_client, pipeline_configuration): @pytest.fixture(autouse=True) def mock_mqtt_transport(mocker): return mocker.patch( - "azure.iot.device.provisioning.pipeline.provisioning_pipeline.pipeline_stages_mqtt.MQTTTransport" - ).return_value + "azure.iot.device.common.pipeline.pipeline_stages_mqtt.MQTTTransport", autospec=True + ) -@pytest.mark.describe("ProvisioningPipeline - Instantiation") -class TestProvisioningPipelineInstantiation(object): +@pytest.mark.describe("MQTTPipeline - Instantiation") +class TestMQTTPipelineInstantiation(object): @pytest.mark.it("Begins tracking the enabled/disabled status of responses") - def test_features(self, input_security_client, pipeline_configuration): - pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + def test_features(self, pipeline_configuration): + pipeline = MQTTPipeline(pipeline_configuration) pipeline.responses_enabled[feature] # No assertion required - if this doesn't raise a KeyError, it is a success @pytest.mark.it("Marks responses as disabled") - def test_features_disabled(self, input_security_client, pipeline_configuration): - pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + def test_features_disabled(self, pipeline_configuration): + pipeline = MQTTPipeline(pipeline_configuration) assert not pipeline.responses_enabled[feature] @pytest.mark.it("Sets all handlers to an initial value of None") - def test_handlers_set_to_none(self, input_security_client, pipeline_configuration): - pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + def test_handlers_set_to_none(self, pipeline_configuration): + pipeline = MQTTPipeline(pipeline_configuration) assert pipeline.on_connected is None assert pipeline.on_disconnected is None assert pipeline.on_message_received is None @pytest.mark.it("Configures the pipeline to trigger handlers in response to external events") - def test_handlers_configured(self, input_security_client, pipeline_configuration): - pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + def test_handlers_configured(self, pipeline_configuration): + pipeline = MQTTPipeline(pipeline_configuration) assert pipeline._pipeline.on_pipeline_event_handler is 
not None assert pipeline._pipeline.on_connected_handler is not None assert pipeline._pipeline.on_disconnected_handler is not None @pytest.mark.it("Configures the pipeline with a series of PipelineStages") - def test_pipeline_configuration(self, input_security_client, pipeline_configuration): - pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + def test_pipeline_configuration(self, pipeline_configuration): + pipeline = MQTTPipeline(pipeline_configuration) curr_stage = pipeline._pipeline expected_stage_order = [ pipeline_stages_base.PipelineRootStage, - pipeline_stages_provisioning.UseSecurityClientStage, + pipeline_stages_base.SasTokenRenewalStage, pipeline_stages_provisioning.RegistrationStage, pipeline_stages_provisioning.PollingStatusStage, pipeline_stages_base.CoordinateRequestAndResponseStage, @@ -184,38 +114,24 @@ def test_pipeline_configuration(self, input_security_client, pipeline_configurat # Assert there are no more additional stages assert curr_stage is None - # TODO: revist these tests after auth revision - # They are too tied to auth types (and there's too much variance in auths to effectively test) - # Ideally ProvisioningPipeline is entirely insulated from any auth differential logic (and module/device distinctions) - # In the meantime, we are using a device auth with connection string to stand in for generic SAS auth - # and device auth with X509 certs to stand in for generic X509 auth - @pytest.mark.it( - "Runs a Set SecurityClient Operation with the provided SecurityClient on the pipeline" - ) - def test_security_client_success(self, mocker, input_security_client, pipeline_configuration): + @pytest.mark.it("Runs an InitializePipelineOperation on the pipeline") + def test_init_pipeline(self, mocker, pipeline_configuration): mocker.spy(pipeline_stages_base.PipelineRootStage, "run_op") - pipeline = ProvisioningPipeline(input_security_client, pipeline_configuration) + + pipeline = MQTTPipeline(pipeline_configuration) op = pipeline._pipeline.run_op.call_args[0][1] assert pipeline._pipeline.run_op.call_count == 1 - if isinstance(input_security_client, X509SecurityClient): - assert isinstance(op, pipeline_ops_provisioning.SetX509SecurityClientOperation) - else: - assert isinstance(op, pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation) - assert op.security_client is input_security_client + assert isinstance(op, pipeline_ops_base.InitializePipelineOperation) @pytest.mark.it( - "Raises exceptions that occurred in execution upon unsuccessful completion of the Set SecurityClient Operation" + "Raises exceptions that occurred in execution upon unsuccessful completion of the InitializePipelineOperation" ) - def test_security_client_failure( - self, mocker, input_security_client, arbitrary_exception, pipeline_configuration - ): + def test_init_pipeline_failure(self, mocker, arbitrary_exception, pipeline_configuration): old_run_op = pipeline_stages_base.PipelineRootStage._run_op def fail_set_security_client(self, op): - if isinstance(input_security_client, X509SecurityClient) or isinstance( - input_security_client, SymmetricKeySecurityClient - ): + if isinstance(op, pipeline_ops_base.InitializePipelineOperation): op.complete(error=arbitrary_exception) else: old_run_op(self, op) @@ -228,12 +144,12 @@ def fail_set_security_client(self, op): ) with pytest.raises(arbitrary_exception.__class__) as e_info: - ProvisioningPipeline(input_security_client, pipeline_configuration) + MQTTPipeline(pipeline_configuration) assert e_info.value is 
arbitrary_exception -@pytest.mark.describe("ProvisioningPipeline - .connect()") -class TestProvisioningPipelineConnect(object): +@pytest.mark.describe("MQTTPipeline - .connect()") +class TestMQTTPipelineConnect(object): @pytest.mark.it("Runs a ConnectOperation on the pipeline") def test_runs_op(self, pipeline, mocker): cb = mocker.MagicMock() @@ -272,8 +188,8 @@ def test_op_fail(self, mocker, pipeline, arbitrary_exception): assert cb.call_args == mocker.call(error=arbitrary_exception) -@pytest.mark.describe("ProvisioningPipeline - .disconnect()") -class TestProvisioningPipelineDisconnect(object): +@pytest.mark.describe("MQTTPipeline - .disconnect()") +class TestMQTTPipelineDisconnect(object): @pytest.mark.it("Runs a DisconnectOperation on the pipeline") def test_runs_op(self, pipeline, mocker): pipeline.disconnect(callback=mocker.MagicMock()) @@ -311,7 +227,7 @@ def test_op_fail(self, mocker, pipeline, arbitrary_exception): assert cb.call_args == mocker.call(error=arbitrary_exception) -@pytest.mark.describe("ProvisioningPipeline - .register()") +@pytest.mark.describe("MQTTPipeline - .register()") class TestSendRegister(object): @pytest.mark.it("Runs a RegisterOperation on the pipeline") def test_runs_op(self, pipeline, mocker): @@ -320,11 +236,12 @@ def test_runs_op(self, pipeline, mocker): assert pipeline._pipeline.run_op.call_count == 1 op = pipeline._pipeline.run_op.call_args[0][0] assert isinstance(op, pipeline_ops_provisioning.RegisterOperation) - assert op.registration_id == fake_registration_id + assert op.registration_id == pipeline._pipeline.pipeline_configuration.registration_id @pytest.mark.it("passes the payload parameter as request_payload on the RegistrationRequest") def test_sets_request_payload(self, pipeline, mocker): cb = mocker.MagicMock() + fake_request_payload = "fake_request_payload" pipeline.register(payload=fake_request_payload, callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] assert op.request_payload is fake_request_payload @@ -350,6 +267,7 @@ def test_op_success_with_callback(self, mocker, pipeline): # Trigger op completion op = pipeline._pipeline.run_op.call_args[0][0] + fake_registration_result = "fake_result" op.registration_result = fake_registration_result op.complete(error=None) @@ -365,6 +283,7 @@ def test_op_fail(self, mocker, pipeline, arbitrary_exception): pipeline.register(callback=cb) op = pipeline._pipeline.run_op.call_args[0][0] + fake_registration_result = "fake_result" op.registration_result = fake_registration_result op.complete(error=arbitrary_exception) @@ -372,7 +291,7 @@ def test_op_fail(self, mocker, pipeline, arbitrary_exception): assert cb.call_args == mocker.call(error=arbitrary_exception, result=None) -@pytest.mark.describe("ProvisioningPipeline - .enable_responses()") +@pytest.mark.describe("MQTTPipeline - .enable_responses()") class TestEnable(object): @pytest.mark.it("Marks the feature as enabled") def test_mark_feature_enabled(self, pipeline, mocker): diff --git a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_ops_provisioning.py b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_ops_provisioning.py index b2bf5e0c1..b24af72c1 100644 --- a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_ops_provisioning.py +++ b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_ops_provisioning.py @@ -14,139 +14,6 @@ pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -class SetSymmetricKeySecurityClientOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return 
pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"security_client": mocker.MagicMock(), "callback": mocker.MagicMock()} - return kwargs - - -class SetSymmetricKeySecurityClientOperationInstantiationTests( - SetSymmetricKeySecurityClientOperationTestConfig -): - @pytest.mark.it( - "Initializes 'security_client' attribute with the provided 'security_client' parameter" - ) - def test_security_client(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.security_client is init_kwargs["security_client"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation, - op_test_config_class=SetSymmetricKeySecurityClientOperationTestConfig, - extended_op_instantiation_test_class=SetSymmetricKeySecurityClientOperationInstantiationTests, -) - - -class SetX509SecurityClientOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_provisioning.SetX509SecurityClientOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = {"security_client": mocker.MagicMock(), "callback": mocker.MagicMock()} - return kwargs - - -class SetX509SecurityClientOperationInstantiationTests(SetX509SecurityClientOperationTestConfig): - @pytest.mark.it( - "Initializes 'security_client' attribute with the provided 'security_client' parameter" - ) - def test_security_client(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.security_client is init_kwargs["security_client"] - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_provisioning.SetX509SecurityClientOperation, - op_test_config_class=SetX509SecurityClientOperationTestConfig, - extended_op_instantiation_test_class=SetX509SecurityClientOperationInstantiationTests, -) - - -class SetProvisioningClientConnectionArgsOperationTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation - - @pytest.fixture - def init_kwargs(self, mocker): - kwargs = { - "provisioning_host": "some_provisioning_host", - "registration_id": "some_registration_id", - "id_scope": "some_id_scope", - "callback": mocker.MagicMock(), - "client_cert": "some_client_cert", - "sas_token": "some_sas_token", - } - return kwargs - - -class SetProvisioningClientConnectionArgsOperationInstantiationTests( - SetProvisioningClientConnectionArgsOperationTestConfig -): - @pytest.mark.it( - "Initializes 'provisioning_host' attribute with the provided 'provisioning_host' parameter" - ) - def test_provisioning_host(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.provisioning_host is init_kwargs["provisioning_host"] - - @pytest.mark.it( - "Initializes 'registration_id' attribute with the provided 'registration_id' parameter" - ) - def test_registration_id(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.registration_id is init_kwargs["registration_id"] - - @pytest.mark.it("Initializes 'id_scope' attribute with the provided 'id_scope' parameter") - def test_id_scope(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.id_scope is init_kwargs["id_scope"] - - @pytest.mark.it("Initializes 'client_cert' attribute with the provided 'client_cert' parameter") - def test_client_cert(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.client_cert is 
init_kwargs["client_cert"] - - @pytest.mark.it( - "Initializes 'client_cert' attribute to None if no 'client_cert' parameter is provided" - ) - def test_client_cert_default(self, cls_type, init_kwargs): - del init_kwargs["client_cert"] - op = cls_type(**init_kwargs) - assert op.client_cert is None - - @pytest.mark.it("Initializes 'sas_token' attribute with the provided 'sas_token' parameter") - def test_sas_token(self, cls_type, init_kwargs): - op = cls_type(**init_kwargs) - assert op.sas_token is init_kwargs["sas_token"] - - @pytest.mark.it( - "Initializes 'sas_token' attribute to None if no 'sas_token' parameter is provided" - ) - def test_sas_token_default(self, cls_type, init_kwargs): - del init_kwargs["sas_token"] - op = cls_type(**init_kwargs) - assert op.sas_token is None - - -pipeline_ops_test.add_operation_tests( - test_module=this_module, - op_class_under_test=pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation, - op_test_config_class=SetProvisioningClientConnectionArgsOperationTestConfig, - extended_op_instantiation_test_class=SetProvisioningClientConnectionArgsOperationInstantiationTests, -) - - class RegisterOperationTestConfig(object): @pytest.fixture def cls_type(self): diff --git a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning.py b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning.py index 09b9466f7..14b2751ac 100644 --- a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning.py +++ b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning.py @@ -5,31 +5,17 @@ # -------------------------------------------------------------------------- import logging import pytest -import functools import sys -from azure.iot.device.common.models.x509 import X509 -from azure.iot.device.provisioning.security.sk_security_client import SymmetricKeySecurityClient -from azure.iot.device.provisioning.security.x509_security_client import X509SecurityClient +import json +import datetime from azure.iot.device.provisioning.pipeline import ( pipeline_stages_provisioning, pipeline_ops_provisioning, ) -from azure.iot.device.common.pipeline import pipeline_ops_base - -from tests.common.pipeline.helpers import ( - assert_callback_succeeded, - assert_callback_failed, - all_common_ops, - all_common_events, - all_except, - StageTestBase, -) -from azure.iot.device.common.pipeline import pipeline_events_base -from tests.provisioning.pipeline.helpers import all_provisioning_ops +from azure.iot.device.common.pipeline import pipeline_ops_base, pipeline_events_base from tests.common.pipeline import pipeline_stage_test from azure.iot.device.exceptions import ServiceError -import json -import datetime + from azure.iot.device.provisioning.models.registration_result import ( RegistrationResult, RegistrationState, @@ -40,14 +26,8 @@ import threading logging.basicConfig(level=logging.DEBUG) - this_module = sys.modules[__name__] - - -# Make it look like we're always running inside pipeline threads -@pytest.fixture(autouse=True) -def apply_fake_pipeline_thread(fake_pipeline_thread): - pass +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") fake_device_id = "elder_wand" @@ -71,58 +51,6 @@ def apply_fake_pipeline_thread(fake_pipeline_thread): fake_pass_phrase = "alohomora" -pipeline_stage_test.add_base_pipeline_stage_tests_old( - cls=pipeline_stages_provisioning.UseSecurityClientStage, - module=this_module, - all_ops=all_common_ops + all_provisioning_ops, - handled_ops=[ - 
pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation, - pipeline_ops_provisioning.SetX509SecurityClientOperation, - ], - all_events=all_common_events, - handled_events=[], -) - - -pipeline_stage_test.add_base_pipeline_stage_tests_old( - cls=pipeline_stages_provisioning.RegistrationStage, - module=this_module, - all_ops=all_common_ops + all_provisioning_ops, - handled_ops=[pipeline_ops_provisioning.RegisterOperation], - all_events=all_common_events, - handled_events=[], -) - - -pipeline_stage_test.add_base_pipeline_stage_tests_old( - cls=pipeline_stages_provisioning.PollingStatusStage, - module=this_module, - all_ops=all_common_ops + all_provisioning_ops, - handled_ops=[pipeline_ops_provisioning.PollStatusOperation], - all_events=all_common_events, - handled_events=[], -) - - -def make_mock_x509_security_client(): - mock_x509 = X509(fake_x509_cert_file, fake_x509_cert_key_file, fake_pass_phrase) - return X509SecurityClient( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=mock_x509, - ) - - -def make_mock_symmetric_security_client(): - return SymmetricKeySecurityClient( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - - class FakeRegistrationResult(object): def __init__(self, operation_id, status, state): self.operationId = operation_id @@ -171,158 +99,6 @@ def op_error(request, arbitrary_exception): return None -############################# -# USE SECURITY CLIENT STAGE # -############################# - - -class UseSecurityClientStageTestConfig(object): - @pytest.fixture - def cls_type(self): - return pipeline_stages_provisioning.UseSecurityClientStage - - @pytest.fixture - def init_kwargs(self): - return {} - - @pytest.fixture - def stage(self, mocker, cls_type, init_kwargs): - stage = cls_type(**init_kwargs) - stage.send_op_down = mocker.MagicMock() - stage.send_event_up = mocker.MagicMock() - return stage - - -pipeline_stage_test.add_base_pipeline_stage_tests( - test_module=this_module, - stage_class_under_test=pipeline_stages_provisioning.UseSecurityClientStage, - stage_test_config_class=UseSecurityClientStageTestConfig, -) - - -@pytest.mark.describe( - "UseSecurityClientStage - .run_op() -- Called with SetSymmetricKeySecurityClientOperation" -) -class TestUseSecurityClientStageRunOpWithSetSymmetricKeySecurityClientOperation( - StageRunOpTestBase, UseSecurityClientStageTestConfig -): - @pytest.fixture - def op(self, mocker): - security_client = SymmetricKeySecurityClient( - provisioning_host="hogwarts.com", - registration_id="registered_remembrall", - id_scope="weasley_wizard_wheezes", - symmetric_key="Zm9vYmFy", - ) - security_client.get_current_sas_token = mocker.MagicMock() - return pipeline_ops_provisioning.SetSymmetricKeySecurityClientOperation( - security_client=security_client, callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Sends a new SetProvisioningClientConnectionArgsOperation op down the pipeline, containing connection info from the op's security client" - ) - def test_send_new_op_down(self, mocker, op, stage): - stage.run_op(op) - - # A SetProvisioningClientConnectionArgsOperation has been sent down the pipeline - stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance( - new_op, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation - ) - - # The SetProvisioningClientConnectionArgsOperation has details from the security 
client - assert new_op.provisioning_host == op.security_client.provisioning_host - assert new_op.registration_id == op.security_client.registration_id - assert new_op.id_scope == op.security_client.id_scope - assert new_op.sas_token == op.security_client.get_current_sas_token.return_value - assert new_op.client_cert is None - - @pytest.mark.it( - "Completes the original SetSymmetricKeySecurityClientOperation with the same status as the new SetProvisioningClientConnectionArgsOperation, if the new SetProvisioningClientConnectionArgsOperation is completed" - ) - def test_new_op_completes_success(self, mocker, op, stage, op_error): - stage.run_op(op) - stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance( - new_op, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation - ) - - assert not op.completed - assert not new_op.completed - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - -@pytest.mark.describe( - "UseSecurityClientStage - .run_op() -- Called with SetX509SecurityClientOperation" -) -class TestUseSecurityClientStageRunOpWithSetX509SecurityClientOperation( - StageRunOpTestBase, UseSecurityClientStageTestConfig -): - @pytest.fixture - def op(self, mocker): - x509 = X509(cert_file="fake_cert.txt", key_file="fake_key.txt", pass_phrase="alohomora") - security_client = X509SecurityClient( - provisioning_host="hogwarts.com", - registration_id="registered_remembrall", - id_scope="weasley_wizard_wheezes", - x509=x509, - ) - security_client.get_x509_certificate = mocker.MagicMock() - return pipeline_ops_provisioning.SetX509SecurityClientOperation( - security_client=security_client, callback=mocker.MagicMock() - ) - - @pytest.mark.it( - "Sends a new SetProvisioningClientConnectionArgsOperation op down the pipeline, containing connection info from the op's security client" - ) - def test_send_new_op_down(self, mocker, op, stage): - stage.run_op(op) - - # A SetProvisioningClientConnectionArgsOperation has been sent down the pipeline - stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance( - new_op, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation - ) - - # The SetProvisioningClientConnectionArgsOperation has details from the security client - assert new_op.provisioning_host == op.security_client.provisioning_host - assert new_op.registration_id == op.security_client.registration_id - assert new_op.id_scope == op.security_client.id_scope - assert new_op.client_cert == op.security_client.get_x509_certificate.return_value - assert new_op.sas_token is None - - @pytest.mark.it( - "Completes the original SetX509SecurityClientOperation with the same status as the new SetProvisioningClientConnectionArgsOperation, if the new SetProvisioningClientConnectionArgsOperation is completed" - ) - def test_new_op_completes_success(self, mocker, op, stage, op_error): - stage.run_op(op) - stage.send_op_down.call_count == 1 - new_op = stage.send_op_down.call_args[0][0] - assert isinstance( - new_op, pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation - ) - - assert not op.completed - assert not new_op.completed - - new_op.complete(error=op_error) - - assert new_op.completed - assert new_op.error is op_error - assert op.completed - assert op.error is op_error - - ############################### # REGISTRATION STAGE # ############################### diff --git 
a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning_mqtt.py b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning_mqtt.py index 873341f81..1fe99f560 100644 --- a/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning_mqtt.py +++ b/azure-iot-device/tests/provisioning/pipeline/test_pipeline_stages_provisioning_mqtt.py @@ -6,440 +6,475 @@ import logging import pytest import sys +import json import six.moves.urllib as urllib -from azure.iot.device import constant +from azure.iot.device import constant as pkg_constant from azure.iot.device.common.pipeline import ( pipeline_ops_base, pipeline_stages_base, pipeline_ops_mqtt, pipeline_events_mqtt, pipeline_events_base, + pipeline_exceptions, ) from azure.iot.device.provisioning.pipeline import ( + config, pipeline_ops_provisioning, pipeline_stages_provisioning_mqtt, ) -from tests.common.pipeline.helpers import ( - assert_callback_failed, - assert_callback_succeeded, - all_common_ops, - all_common_events, - all_except, - StageTestBase, -) -from tests.provisioning.pipeline.helpers import all_provisioning_ops -from tests.common.pipeline import pipeline_stage_test -import json from azure.iot.device.provisioning.pipeline import constant as pipeline_constant -from azure.iot.device.product_info import ProductInfo +from azure.iot.device import user_agent +from tests.common.pipeline.helpers import StageRunOpTestBase, StageHandlePipelineEventTestBase +from tests.common.pipeline import pipeline_stage_test logging.basicConfig(level=logging.DEBUG) - this_module = sys.modules[__name__] +pytestmark = pytest.mark.usefixtures("fake_pipeline_thread") -# This fixture makes it look like all test in this file tests are running -# inside the pipeline thread. Because this is an autouse fixture, we -# manually add it to the individual test.py files that need it. If, -# instead, we had added it to some conftest.py, it would be applied to -# every tests in every file and we don't want that. 
-@pytest.fixture(autouse=True) -def apply_fake_pipeline_thread(fake_pipeline_thread): - pass - - -fake_device_id = "elder_wand" -fake_registration_id = "registered_remembrall" -fake_provisioning_host = "hogwarts.com" -fake_id_scope = "weasley_wizard_wheezes" -fake_sas_token = "horcrux_token" -fake_security_client = "secure_via_muffliato" -fake_request_id = "fake_request_1234" -fake_mqtt_payload = "hello hogwarts" -fake_operation_id = "fake_operation_9876" -fake_client_cert = "fake_client_cert" - -invalid_feature_name = "__invalid_feature_name__" -unmatched_mqtt_topic = "__unmatched_mqtt_topic__" - -fake_response_topic = "$dps/registrations/res/200/?$rid={}".format(fake_request_id) - -ops_handled_by_this_stage = [ - pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation, - pipeline_ops_base.RequestOperation, - pipeline_ops_base.EnableFeatureOperation, - pipeline_ops_base.DisableFeatureOperation, -] - -events_handled_by_this_stage = [pipeline_events_mqtt.IncomingMQTTMessageEvent] - -pipeline_stage_test.add_base_pipeline_stage_tests_old( - cls=pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage, - module=this_module, - all_ops=all_common_ops + all_provisioning_ops, - handled_ops=ops_handled_by_this_stage, - all_events=all_common_events, - handled_events=events_handled_by_this_stage, - extra_initializer_defaults={"action_to_topic": dict}, -) +@pytest.fixture(params=[True, False], ids=["With error", "No error"]) +def op_error(request, arbitrary_exception): + if request.param: + return arbitrary_exception + else: + return None @pytest.fixture -def set_security_client_args(mocker): - op = pipeline_ops_provisioning.SetProvisioningClientConnectionArgsOperation( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - sas_token=fake_sas_token, - client_cert=fake_client_cert, - callback=mocker.MagicMock(), +def mock_mqtt_topic(mocker): + m = mocker.patch( + "azure.iot.device.provisioning.pipeline.pipeline_stages_provisioning_mqtt.mqtt_topic_provisioning" ) - mocker.spy(op, "complete") - return op + return m -class ProvisioningMQTTTranslationStageTestBase(StageTestBase): +class ProvisioningMQTTTranslationStageTestConfig(object): + @pytest.fixture + def cls_type(self): + return pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage + @pytest.fixture - def stage(self): - return pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage() + def init_kwargs(self): + return {} @pytest.fixture - def stages_configured(self, stage, stage_base_configuration, set_security_client_args, mocker): - mocker.spy(stage.pipeline_root, "handle_pipeline_event") + def pipeline_config(self, mocker): + # auth type shouldn't matter for this stage, so just give it a fake sastoken for now. 
+ cfg = config.ProvisioningPipelineConfig( + hostname="http://my.hostname", + registration_id="fake_reg_id", + id_scope="fake_id_scope", + sastoken=mocker.MagicMock(), + ) + return cfg - stage.run_op(set_security_client_args) - mocker.resetall() + @pytest.fixture + def stage(self, mocker, cls_type, init_kwargs, pipeline_config): + stage = cls_type(**init_kwargs) + stage.pipeline_root = pipeline_stages_base.PipelineRootStage(pipeline_config) + stage.send_op_down = mocker.MagicMock() + stage.send_event_up = mocker.MagicMock() + return stage + + +pipeline_stage_test.add_base_pipeline_stage_tests( + test_module=this_module, + stage_class_under_test=pipeline_stages_provisioning_mqtt.ProvisioningMQTTTranslationStage, + stage_test_config_class=ProvisioningMQTTTranslationStageTestConfig, +) @pytest.mark.describe( - "ProvisioningMQTTTranslationStage run_op function with SetProvisioningClientConnectionArgsOperation" + "ProvisioningMQTTTranslationStage - .run_op() -- Called with InitializePipelineOperation" ) -class TestProvisioningMQTTTranslationStageWithSetProvisioningClientConnectionArgsOperation( - ProvisioningMQTTTranslationStageTestBase +class TestProvisioningMQTTTranslationStageRunOpWithInitializePipelineOperation( + StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig ): - @pytest.mark.it( - "Runs a pipeline_ops_mqtt.SetMQTTConnectionArgsOperation operation on the next stage" - ) - def test_runs_set_connection_args(self, stage, set_security_client_args): - stage.run_op(set_security_client_args) - assert stage.next._run_op.call_count == 1 - new_op = stage.next._run_op.call_args[0][0] - assert isinstance(new_op, pipeline_ops_mqtt.SetMQTTConnectionArgsOperation) + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.InitializePipelineOperation(callback=mocker.MagicMock()) - @pytest.mark.it( - "Sets SetMQTTConnectionArgsOperation.client_id = SetProvisioningClientConnectionArgsOperation.registration_id" - ) - def test_sets_client_id(self, stage, set_security_client_args): - stage.run_op(set_security_client_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.client_id == fake_registration_id + @pytest.mark.it("Derives the MQTT client id, and sets it on the op") + def test_client_id(self, stage, op, pipeline_config): + assert not hasattr(op, "client_id") + stage.run_op(op) - @pytest.mark.it( - "Sets SetMQTTConnectionArgsOperation.hostname = SetProvisioningClientConnectionArgsOperation.provisioning_host" - ) - def test_sets_hostname(self, stage, set_security_client_args): - stage.run_op(set_security_client_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.hostname == fake_provisioning_host + assert op.client_id == pipeline_config.registration_id - @pytest.mark.it( - "Sets SetMQTTConnectionArgsOperation.client_cert = SetProvisioningClientConnectionArgsOperation.client_cert" - ) - def test_sets_client_cert(self, stage, set_security_client_args): - stage.run_op(set_security_client_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.client_cert == fake_client_cert + @pytest.mark.it("Derives the MQTT username, and sets it on the op") + def test_username(self, stage, op, pipeline_config): + assert not hasattr(op, "username") + stage.run_op(op) + + expected_username = "{id_scope}/registrations/{registration_id}/api-version={api_version}&ClientVersion={user_agent}".format( + id_scope=pipeline_config.id_scope, + registration_id=pipeline_config.registration_id, + api_version=pkg_constant.PROVISIONING_API_VERSION, + 
user_agent=urllib.parse.quote(user_agent.get_provisioning_user_agent(), safe=""), + ) + assert op.username == expected_username + + @pytest.mark.it("Sends the op down the pipeline") + def test_sends_down(self, mocker, stage, op): + stage.run_op(op) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) + + +@pytest.mark.describe( + "ProvisioningMQTTTranslationStage - .run_op() -- Called with RequestOperation (Register Request)" +) +class TestProvisioningMQTTTranslationStageRunOpWithRequestOperationRegister( + StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.RequestOperation( + request_type=pipeline_constant.REGISTER, + method="PUT", + resource_location="/", + request_body='{"json": "payload"}', + request_id="fake_request_id", + callback=mocker.MagicMock(), + ) + + @pytest.mark.it("Derives the Provisioning Register Request topic using the op's details") + def test_register_request_topic(self, mocker, stage, op, mock_mqtt_topic): + stage.run_op(op) + + assert mock_mqtt_topic.get_register_topic_for_publish.call_count == 1 + assert mock_mqtt_topic.get_register_topic_for_publish.call_args == mocker.call( + request_id=op.request_id + ) @pytest.mark.it( - "Sets SetMQTTConnectionArgsOperation.sas_token = SetProvisioningClientConnectionArgsOperation.sas_token" + "Sends a new MQTTPublishOperation down the pipeline with the original op's request body and the derived topic string" ) - def test_sets_sas_token(self, stage, set_security_client_args): - stage.run_op(set_security_client_args) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.sas_token == fake_sas_token + def test_sends_mqtt_publish_down(self, mocker, stage, op, mock_mqtt_topic): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) + assert new_op.topic == mock_mqtt_topic.get_register_topic_for_publish.return_value + assert new_op.payload == op.request_body + + @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") + def test_complete_resulting_op(self, stage, op, op_error): + stage.run_op(op) + assert not op.completed + assert op.error is None + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) + + new_op.complete(error=op_error) + + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error + + +@pytest.mark.describe( + "ProvisioningMQTTTranslationStage - .run_op() -- Called with RequestOperation (Query Request)" +) +class TestProvisioningMQTTTranslationStageRunOpWithRequestOperationQuery( + StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.RequestOperation( + request_type=pipeline_constant.QUERY, + method="GET", + resource_location="/", + query_params={"operation_id": "fake_op_id"}, + request_body="some body", + request_id="fake_request_id", + callback=mocker.MagicMock(), + ) + + @pytest.mark.it("Derives the Provisioning Query Request topic using the op's details") + def test_query_request_topic(self, mocker, stage, op, mock_mqtt_topic): + stage.run_op(op) + + assert mock_mqtt_topic.get_query_topic_for_publish.call_count == 1 + assert mock_mqtt_topic.get_query_topic_for_publish.call_args ==
mocker.call( + request_id=op.request_id, operation_id=op.query_params["operation_id"] + ) @pytest.mark.it( - "Sets MqttConnectionArgsOperation.username = SetProvisioningClientConnectionArgsOperation.{id_scope}/registrations/{registration_id}/api-version={api_version}&ClientVersion={client_version}" + "Sends a new MQTTPublishOperation down the pipeline with the original op's request body and the derived topic string" ) - def test_sets_username(self, stage, set_security_client_args): - stage.run_op(set_security_client_args) - new_op = stage.next._run_op.call_args[0][0] - assert ( - new_op.username - == "{id_scope}/registrations/{registration_id}/api-version={api_version}&ClientVersion={client_version}".format( - id_scope=fake_id_scope, - registration_id=fake_registration_id, - api_version=constant.PROVISIONING_API_VERSION, - client_version=urllib.parse.quote( - ProductInfo.get_provisioning_user_agent(), safe="" - ), - ) + def test_sends_mqtt_publish_down(self, mocker, stage, op, mock_mqtt_topic): + stage.run_op(op) + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) + assert new_op.topic == mock_mqtt_topic.get_query_topic_for_publish.return_value + assert new_op.payload == op.request_body + + @pytest.mark.it("Completes the original op upon completion of the new MQTTPublishOperation") + def test_complete_resulting_op(self, stage, op, op_error): + stage.run_op(op) + assert not op.completed + assert op.error is None + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTPublishOperation) + + new_op.complete(error=op_error) + + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error + + +@pytest.mark.describe( + "ProvisioningMQTTTranslationStage - .run_op() -- Called with RequestOperation (Unsupported Request Type)" +) +class TestProvisioningMQTTTranslationStageRunOpWithRequestOperationUnsupportedType( + StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.RequestOperation( + request_type="FAKE_REQUEST_TYPE", + method="GET", + resource_location="/", + request_body="some body", + request_id="fake_request_id", + callback=mocker.MagicMock(), + ) + + @pytest.mark.it("Completes the operation with an OperationError failure") + def test_fail(self, mocker, stage, op): + assert not op.completed + assert op.error is None + + stage.run_op(op) + + assert op.completed + assert isinstance(op.error, pipeline_exceptions.OperationError) + + +@pytest.mark.describe( + "ProvisioningMQTTTranslationStage - .run_op() -- Called with EnableFeatureOperation" +) +class TestProvisioningMQTTTranslationStageRunOpWithEnableFeatureOperation( + StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.EnableFeatureOperation( + feature_name=pipeline_constant.REGISTER, callback=mocker.MagicMock() ) @pytest.mark.it( - "Completes the SetSymmetricKeySecurityClientArgs op with error if the pipeline_ops_mqtt.SetMQTTConnectionArgsOperation operation raises an Exception" + "Sends a new MQTTSubscribeOperation down the pipeline, containing the subscription topic for Register, if Register is the feature being enabled" ) - def test_set_connection_args_raises_exception( - self, stage, mocker, arbitrary_exception, set_security_client_args -
): - stage.next._run_op = mocker.Mock(side_effect=arbitrary_exception) - stage.run_op(set_security_client_args) - assert set_security_client_args.complete.call_count == 1 - assert set_security_client_args.complete.call_args == mocker.call(error=arbitrary_exception) + def test_mqtt_subscribe_sent_down(self, mocker, op, stage, mock_mqtt_topic): + stage.run_op(op) + + # Topic was derived as expected + assert mock_mqtt_topic.get_register_topic_for_subscribe.call_count == 1 + assert mock_mqtt_topic.get_register_topic_for_subscribe.call_args == mocker.call() + + # New op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTSubscribeOperation) + + # New op has the expected topic + assert new_op.topic == mock_mqtt_topic.get_register_topic_for_subscribe.return_value + + @pytest.mark.it("Completes the original op upon completion of the new MQTTSubscribeOperation") + def test_complete_resulting_op(self, stage, op, op_error): + stage.run_op(op) + assert not op.completed + assert op.error is None + + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + + new_op.complete(error=op_error) + + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error @pytest.mark.it( - "Calls the SetSymmetricKeySecurityClientArgs callback with no error if the pipeline_ops_mqtt.SetMQTTConnectionArgsOperation operation succeeds" + "Completes the operation with an OperationError failure if the feature being enabled is of any type other than Register" ) - def test_returns_success_if_set_connection_args_succeeds( - self, stage, mocker, set_security_client_args, next_stage_succeeds - ): - stage.run_op(set_security_client_args) - assert set_security_client_args.complete.call_count == 1 - assert set_security_client_args.complete.call_args == mocker.call(error=None) - - -basic_ops = [ - { - "op_class": pipeline_ops_base.RequestOperation, - "op_init_kwargs": { - "request_id": fake_request_id, - "request_type": pipeline_constant.REGISTER, - "method": "PUT", - "resource_location": "/", - "request_body": "test payload", - }, - "new_op_class": pipeline_ops_mqtt.MQTTPublishOperation, - }, - { - "op_class": pipeline_ops_base.RequestOperation, - "op_init_kwargs": { - "request_id": fake_request_id, - "request_type": pipeline_constant.QUERY, - "method": "GET", - "resource_location": "/", - "query_params": {"operation_id": fake_operation_id}, - "request_body": "test payload", - }, - "new_op_class": pipeline_ops_mqtt.MQTTPublishOperation, - }, - { - "op_class": pipeline_ops_base.EnableFeatureOperation, - "op_init_kwargs": {"feature_name": None}, - "new_op_class": pipeline_ops_mqtt.MQTTSubscribeOperation, - }, - { - "op_class": pipeline_ops_base.DisableFeatureOperation, - "op_init_kwargs": {"feature_name": None}, - "new_op_class": pipeline_ops_mqtt.MQTTUnsubscribeOperation, - }, -] + def test_unsupported_feature(self, stage, op): + op.feature_name = "invalid feature" + assert not op.completed + assert op.error is None + stage.run_op(op) -@pytest.fixture -def op(params, mocker): - op = params["op_class"](callback=mocker.MagicMock(), **params["op_init_kwargs"]) - mocker.spy(op, "complete") - return op + assert op.completed + assert isinstance(op.error, pipeline_exceptions.OperationError) -@pytest.mark.parametrize( - "params", - basic_ops, - ids=["{}->{}".format(x["op_class"].__name__, x["new_op_class"].__name__) for x in basic_ops], +@pytest.mark.describe( + 
"ProvisioningMQTTTranslationStage - .run_op() -- Called with DisableFeatureOperation" ) -@pytest.mark.describe("ProvisioningMQTTTranslationStage basic operation tests") -class TestProvisioningMQTTTranslationStageBasicOperations(ProvisioningMQTTTranslationStageTestBase): - @pytest.mark.it("Runs an operation on the next stage") - def test_runs_publish(self, params, stage, stages_configured, op): - stage.run_op(op) - new_op = stage.next._run_op.call_args[0][0] - assert isinstance(new_op, params["new_op_class"]) - - @pytest.mark.it("Completes the original op with error if the new_op raises an Exception") - def test_new_op_raises_exception( - self, params, mocker, stage, stages_configured, op, arbitrary_exception - ): - stage.next._run_op = mocker.Mock(side_effect=arbitrary_exception) - stage.run_op(op) - assert op.complete.call_count == 1 - assert op.complete.call_args == mocker.call(error=arbitrary_exception) - - @pytest.mark.it("Allows any BaseExceptions raised from inside new_op to propagate") - def test_new_op_raises_base_exception( - self, params, mocker, stage, stages_configured, op, arbitrary_base_exception - ): - stage.next._run_op = mocker.Mock(side_effect=arbitrary_base_exception) - with pytest.raises(arbitrary_base_exception.__class__) as e_info: - stage.run_op(op) - e_info.value is arbitrary_base_exception - - @pytest.mark.it("Completes the original op with no error if the new_op operation succeeds") - def test_returns_success_if_publish_succeeds( - self, mocker, params, stage, stages_configured, op, next_stage_succeeds - ): - stage.run_op(op) - assert op.complete.call_count == 1 - assert op.complete.call_args == mocker.call(error=None) - - -publish_ops = [ - { - "name": "send register request with no payload", - "op_class": pipeline_ops_base.RequestOperation, - "op_init_kwargs": { - "request_id": fake_request_id, - "request_type": pipeline_constant.REGISTER, - "method": "PUT", - "resource_location": "/", - "request_body": '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( - reg_id=fake_registration_id, json_payload=json.dumps(None) - ), - }, - "topic": "$dps/registrations/PUT/iotdps-register/?$rid={request_id}".format( - request_id=fake_request_id - ), - "publish_payload": '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( - reg_id=fake_registration_id, json_payload=json.dumps(None) - ), - }, - { - "name": "send register request with payload", - "op_class": pipeline_ops_base.RequestOperation, - "op_init_kwargs": { - "request_id": fake_request_id, - "request_type": pipeline_constant.REGISTER, - "method": "PUT", - "resource_location": "/", - "request_body": '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( - reg_id=fake_registration_id, json_payload=json.dumps(fake_mqtt_payload) - ), - }, - "topic": "$dps/registrations/PUT/iotdps-register/?$rid={request_id}".format( - request_id=fake_request_id - ), - "publish_payload": '{{"payload": {json_payload}, "registrationId": "{reg_id}"}}'.format( - reg_id=fake_registration_id, json_payload=json.dumps(fake_mqtt_payload) - ), - }, - { - "name": "send query request", - "op_class": pipeline_ops_base.RequestOperation, - "op_init_kwargs": { - "request_id": fake_request_id, - "query_params": {"operation_id": fake_operation_id}, - "request_type": pipeline_constant.QUERY, - "method": "GET", - "resource_location": "/", - "request_body": fake_mqtt_payload, - }, - "topic": "$dps/registrations/GET/iotdps-get-operationstatus/?$rid={request_id}&operationId={operation_id}".format( - 
request_id=fake_request_id, operation_id=fake_operation_id - ), - "publish_payload": fake_mqtt_payload, - }, -] - - -@pytest.mark.parametrize("params", publish_ops, ids=[x["name"] for x in publish_ops]) -@pytest.mark.describe("ProvisioningMQTTTranslationStage run_op function for publish operations") -class TestProvisioningMQTTTranslationStageForPublishOps(ProvisioningMQTTTranslationStageTestBase): - @pytest.mark.it("Uses correct registration topic string when publishing") - def test_uses_topic_for(self, stage, stages_configured, params, op): +class TestProvisioningMQTTTranslationStageRunOpWithDisableFeatureOperation( + StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig +): + @pytest.fixture + def op(self, mocker): + return pipeline_ops_base.DisableFeatureOperation( + feature_name=pipeline_constant.REGISTER, callback=mocker.MagicMock() + ) + + @pytest.mark.it( + "Sends a new MQTTUnsubscribeOperation down the pipeline, containing the subscription topic for Register, if Register is the feature being disabled" + ) + def test_mqtt_unsubscribe_sent_down(self, mocker, op, stage, mock_mqtt_topic): stage.run_op(op) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.topic == params["topic"] - @pytest.mark.it("Sends correct payload when publishing") - def test_sends_correct_body(self, stage, stages_configured, params, op): + # Topic was derived as expected + assert mock_mqtt_topic.get_register_topic_for_subscribe.call_count == 1 + assert mock_mqtt_topic.get_register_topic_for_subscribe.call_args == mocker.call() + + # New op was sent down + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] + assert isinstance(new_op, pipeline_ops_mqtt.MQTTUnsubscribeOperation) + + # New op has the expected topic + assert new_op.topic == mock_mqtt_topic.get_register_topic_for_subscribe.return_value + + @pytest.mark.it("Completes the original op upon completion of the new MQTTUnsubscribeOperation") + def test_complete_resulting_op(self, stage, op, op_error): stage.run_op(op) - new_op = stage.next._run_op.call_args[0][0] - assert new_op.payload == params["publish_payload"] + assert not op.completed + assert op.error is None + assert stage.send_op_down.call_count == 1 + new_op = stage.send_op_down.call_args[0][0] -sub_unsub_operations = [ - { - "op_class": pipeline_ops_base.EnableFeatureOperation, - "new_op": pipeline_ops_mqtt.MQTTSubscribeOperation, - }, - { - "op_class": pipeline_ops_base.DisableFeatureOperation, - "new_op": pipeline_ops_mqtt.MQTTUnsubscribeOperation, - }, -] + new_op.complete(error=op_error) + assert new_op.completed + assert new_op.error is op_error + assert op.completed + assert op.error is op_error -@pytest.mark.describe( - "ProvisioningMQTTTranslationStage run_op function with EnableFeature operation" -) -class TestProvisioningMQTTTranslationStageWithEnable(ProvisioningMQTTTranslationStageTestBase): - @pytest.mark.parametrize( - "op_parameters", - sub_unsub_operations, - ids=[x["op_class"].__name__ for x in sub_unsub_operations], + @pytest.mark.it( + "Completes the operation with an OperationError failure if the feature being disabled is of any type other than Register" ) - @pytest.mark.it("Gets the correct topic") - def test_converts_feature_name_to_topic(self, mocker, stage, stages_configured, op_parameters): - topic = "$dps/registrations/res/#" - stage.next._run_op = mocker.Mock() + def test_unsupported_feature(self, stage, op): + op.feature_name = "invalid feature" + assert not op.completed + assert op.error is None - op = 
op_parameters["op_class"](feature_name=None, callback=mocker.MagicMock()) stage.run_op(op) - new_op = stage.next._run_op.call_args[0][0] - assert isinstance(new_op, op_parameters["new_op"]) - assert new_op.topic == topic + + assert op.completed + assert isinstance(op.error, pipeline_exceptions.OperationError) -@pytest.mark.describe("ProvisioningMQTTTranslationStage _handle_pipeline_event") -class TestProvisioningMQTTTranslationStageHandlePipelineEvent( - ProvisioningMQTTTranslationStageTestBase +@pytest.mark.describe( + "IoTHubMQTTTranslationStage - .run_op() -- Called with other arbitrary operation" +) +class TestProvisioningMQTTTranslationStageRunOpWithArbitraryOperation( + StageRunOpTestBase, ProvisioningMQTTTranslationStageTestConfig ): - @pytest.mark.it("Passes up any mqtt messages with topics that aren't matched by this stage") - def test_passes_up_mqtt_message_with_unknown_topic(self, stage, stages_configured, mocker): - event = pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=unmatched_mqtt_topic, payload=fake_mqtt_payload - ) - stage.handle_pipeline_event(event) - assert stage.previous.handle_pipeline_event.call_count == 1 - assert stage.previous.handle_pipeline_event.call_args == mocker.call(event) + @pytest.fixture + def op(self, arbitrary_op): + return arbitrary_op + @pytest.mark.it("Sends the operation down the pipeline") + def test_sends_op_down(self, mocker, stage, op): + stage.run_op(op) -@pytest.fixture -def dps_response_event(): - return pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=fake_response_topic, payload=fake_mqtt_payload - ) + assert stage.send_op_down.call_count == 1 + assert stage.send_op_down.call_args == mocker.call(op) -@pytest.mark.describe("ProvisioningMQTTTranslationStage _handle_pipeline_event for response") -class TestProvisioningMQTTConverterHandlePipelineEventRegistrationResponse( - ProvisioningMQTTTranslationStageTestBase +@pytest.mark.it( + "IoTHubMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (DPS Response Topic)" +) +class TestProvisioningMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventDPSResponseTopic( + StageHandlePipelineEventTestBase, ProvisioningMQTTTranslationStageTestConfig ): + @pytest.fixture + def status(self): + return 200 + + @pytest.fixture + def rid(self): + return "3226c2f7-3d30-425c-b83b-0c34335f8220" + + @pytest.fixture(params=["With retry-after", "No retry-after"]) + def retry_after(self, request): + if request.param == "With retry-after": + return "1234" + else: + return None + + @pytest.fixture + def event(self, status, rid, retry_after): + topic = "$dps/registrations/res/{status}/?$rid={rid}".format(status=status, rid=rid) + if retry_after: + topic = topic + "&retry-after={}".format(retry_after) + return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=b"some payload") + @pytest.mark.it( - "Converts mqtt message with topic $dps/registrations/res/#/ to registration response event" + "Sends a ResponseEvent up the pipeline containing the original event's payload and values extracted from the topic string" ) - def test_converts_response_topic_to_registration_response_event( - self, mocker, stage, stages_configured, dps_response_event - ): - stage.handle_pipeline_event(dps_response_event) - assert stage.previous.handle_pipeline_event.call_count == 1 - new_event = stage.previous.handle_pipeline_event.call_args[0][0] + def test_response_event(self, event, stage, status, rid, retry_after): + stage.handle_pipeline_event(event) + + assert 
stage.send_event_up.call_count == 1 + new_event = stage.send_event_up.call_args[0][0] assert isinstance(new_event, pipeline_events_base.ResponseEvent) + assert new_event.status_code == status + assert new_event.request_id == rid + assert new_event.retry_after == retry_after + assert new_event.response_body == event.payload - @pytest.mark.it("Extracts message properties from the mqtt topic for c2d messages") - def test_extracts_some_properties_from_topic( - self, mocker, stage, stages_configured, dps_response_event - ): - stage.handle_pipeline_event(dps_response_event) - new_event = stage.previous.handle_pipeline_event.call_args[0][0] - assert new_event.request_id == fake_request_id - assert new_event.status_code == 200 - - @pytest.mark.it("Passes up other messages") - def test_if_topic_is_not_response(self, mocker, stage, stages_configured): - fake_some_other_topic = "devices/{}/messages/devicebound/".format(fake_device_id) - event = pipeline_events_mqtt.IncomingMQTTMessageEvent( - topic=fake_some_other_topic, payload=fake_mqtt_payload - ) + +@pytest.mark.describe( + "ProvisioningMQTTTranslationStage - .handle_pipeline_event() -- Called with IncomingMQTTMessageEvent (Unrecognized topic string)" +) +class TestProvisioningMQTTTranslationStageHandlePipelineEventWithIncomingMQTTMessageEventUnknownTopicString( + StageHandlePipelineEventTestBase, ProvisioningMQTTTranslationStageTestConfig +): + @pytest.fixture + def event(self): + topic = "not a real topic" + return pipeline_events_mqtt.IncomingMQTTMessageEvent(topic=topic, payload=b"some payload") + + @pytest.mark.it("Sends the event up the pipeline") + def test_sends_up(self, event, stage): stage.handle_pipeline_event(event) - assert stage.previous.handle_pipeline_event.call_count == 1 - assert stage.previous.handle_pipeline_event.call_args == mocker.call(event) + + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args[0][0] == event + + +@pytest.mark.describe( + "ProvisioningMQTTTranslationStage - .handle_pipeline_event() -- Called with other arbitrary event" +) +class TestProvisioningMQTTTranslationStageHandlePipelineEventWithArbitraryEvent( + StageHandlePipelineEventTestBase, ProvisioningMQTTTranslationStageTestConfig +): + @pytest.fixture + def event(self, arbitrary_event): + return arbitrary_event + + @pytest.mark.it("Sends the event up the pipeline") + def test_sends_up(self, event, stage): + stage.handle_pipeline_event(event) + + assert stage.send_event_up.call_count == 1 + assert stage.send_event_up.call_args[0][0] == event diff --git a/azure-iot-device/tests/provisioning/security/test_sk_security_client.py b/azure-iot-device/tests/provisioning/security/test_sk_security_client.py deleted file mode 100644 index 7b40824bc..000000000 --- a/azure-iot-device/tests/provisioning/security/test_sk_security_client.py +++ /dev/null @@ -1,52 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.provisioning.security.sk_security_client import SymmetricKeySecurityClient - -logging.basicConfig(level=logging.DEBUG) - -fake_symmetric_key = "Zm9vYmFy" -key_name = "registration" -fake_provisioning_host = "beauxbatons.academy-net" -fake_registration_id = "MyPensieve" -module_id = "Divination" -fake_id_scope = "Enchanted0000Ceiling7898" -signature = "IsolemnlySwearThatIamuUptoNogood" -expiry = "1539043658" - - -@pytest.mark.describe("SymmetricKeySecurityClient") -class TestSymmetricKeySecurityClient(object): - @pytest.mark.it("Properties have getters") - def test_properties_are_gettable_after_instantiation_security_client(self): - security_client = SymmetricKeySecurityClient( - fake_provisioning_host, fake_registration_id, fake_id_scope, fake_symmetric_key - ) - assert security_client.provisioning_host == fake_provisioning_host - assert security_client.id_scope == fake_id_scope - assert security_client.registration_id == fake_registration_id - - @pytest.mark.it("Properties do not have setter") - def test_properties_are_not_settable_after_instantiation_security_client(self): - security_client = SymmetricKeySecurityClient( - fake_provisioning_host, fake_registration_id, fake_id_scope, fake_symmetric_key - ) - with pytest.raises(AttributeError, match="can't set attribute"): - security_client.registration_id = "MyNimbus2000" - security_client.id_scope = "WhompingWillow" - security_client.provisioning_host = "hogwarts.com" - - @pytest.mark.it("Can create sas token") - def test_create_sas(self): - security_client = SymmetricKeySecurityClient( - fake_provisioning_host, fake_registration_id, fake_id_scope, fake_symmetric_key - ) - sas_value = security_client.get_current_sas_token() - assert key_name in sas_value - assert fake_registration_id in sas_value - assert fake_id_scope in sas_value diff --git a/azure-iot-device/tests/provisioning/security/test_x509_security_client.py b/azure-iot-device/tests/provisioning/security/test_x509_security_client.py deleted file mode 100644 index 08222838b..000000000 --- a/azure-iot-device/tests/provisioning/security/test_x509_security_client.py +++ /dev/null @@ -1,50 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.provisioning.security.x509_security_client import X509SecurityClient -from azure.iot.device.common.models.x509 import X509 - -logging.basicConfig(level=logging.DEBUG) - -fake_provisioning_host = "beauxbatons.academy-net" -fake_registration_id = "MyPensieve" -module_id = "Divination" -fake_id_scope = "Enchanted0000Ceiling7898" -signature = "IsolemnlySwearThatIamuUptoNogood" -expiry = "1539043658" -fake_x509_cert_value = "fantastic_beasts" -fake_x509_cert_key = "where_to_find_them" -fake_pass_phrase = "alohomora" - - -def x509(): - return X509(fake_x509_cert_value, fake_x509_cert_key, fake_pass_phrase) - - -@pytest.mark.describe("X509SecurityClient") -class TestX509SecurityClient(object): - @pytest.mark.it("Properties have getters") - def test_properties_are_gettable_after_instantiation_security_client(self): - x509_cert = x509() - security_client = X509SecurityClient( - fake_provisioning_host, fake_registration_id, fake_id_scope, x509_cert - ) - assert security_client.provisioning_host == fake_provisioning_host - assert security_client.id_scope == fake_id_scope - assert security_client.registration_id == fake_registration_id - assert security_client.get_x509_certificate() == x509_cert - - @pytest.mark.it("Properties do not have setter") - def test_properties_are_not_settable_after_instantiation_security_client(self): - security_client = X509SecurityClient( - fake_provisioning_host, fake_registration_id, fake_id_scope, x509() - ) - with pytest.raises(AttributeError, match="can't set attribute"): - security_client.registration_id = "MyNimbus2000" - security_client.id_scope = "WhompingWillow" - security_client.provisioning_host = "hogwarts.com" diff --git a/azure-iot-device/tests/provisioning/shared_client_fixtures.py b/azure-iot-device/tests/provisioning/shared_client_fixtures.py new file mode 100644 index 000000000..bca1c4360 --- /dev/null +++ b/azure-iot-device/tests/provisioning/shared_client_fixtures.py @@ -0,0 +1,66 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +"""This module contains test fixtures shared between sync/async client tests""" +import pytest +from azure.iot.device.provisioning.models.registration_result import ( + RegistrationResult, + RegistrationState, +) +from azure.iot.device.common.models.x509 import X509 + +"""Constants""" +fake_x509_cert_file_value = "fantastic_beasts" +fake_x509_cert_key_file = "where_to_find_them" +fake_pass_phrase = "alohomora" +fake_status = "flying" +fake_sub_status = "FlyingOnHippogriff" +fake_operation_id = "quidditch_world_cup" +fake_device_id = "MyNimbus2000" +fake_assigned_hub = "Dumbledore'sArmy" + + +"""Pipeline fixtures""" + + +@pytest.fixture +def mock_pipeline_init(mocker): + return mocker.patch("azure.iot.device.provisioning.pipeline.MQTTPipeline") + + +@pytest.fixture(autouse=True) +def provisioning_pipeline(mocker): + return mocker.MagicMock(wraps=FakeProvisioningPipeline()) + + +class FakeProvisioningPipeline: + def __init__(self): + self.responses_enabled = {} + + def connect(self, callback): + callback() + + def disconnect(self, callback): + callback() + + def enable_responses(self, callback): + callback() + + def register(self, payload, callback): + callback(result={}) + + +"""Parameter fixtures""" + + +@pytest.fixture +def registration_result(): + registration_state = RegistrationState(fake_device_id, fake_assigned_hub, fake_sub_status) + return RegistrationResult(fake_operation_id, fake_status, registration_state) + + +@pytest.fixture +def x509(): + return X509(fake_x509_cert_file_value, fake_x509_cert_key_file, fake_pass_phrase) diff --git a/azure-iot-device/tests/provisioning/shared_client_tests.py b/azure-iot-device/tests/provisioning/shared_client_tests.py new file mode 100644 index 000000000..437773935 --- /dev/null +++ b/azure-iot-device/tests/provisioning/shared_client_tests.py @@ -0,0 +1,251 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +"""This module contains tests that are shared between sync/async clients +i.e. 
tests for things defined in abstract clients""" + +import pytest +import logging +import socks + +from azure.iot.device.common import auth +from azure.iot.device.common.auth import sastoken as st +from azure.iot.device.provisioning.pipeline import ProvisioningPipelineConfig +from azure.iot.device import ProxyOptions + +logging.basicConfig(level=logging.DEBUG) + + +fake_provisioning_host = "hogwarts.com" +fake_registration_id = "MyPensieve" +fake_id_scope = "Enchanted0000Ceiling7898" +fake_symmetric_key = "Zm9vYmFy" + + +class SharedProvisioningClientInstantiationTests(object): + @pytest.mark.it( + "Stores the ProvisioningPipeline from the 'pipeline' parameter in the '_pipeline' attribute" + ) + def test_sets_provisioning_pipeline(self, client_class, provisioning_pipeline): + client = client_class(provisioning_pipeline) + + assert client._pipeline is provisioning_pipeline + + @pytest.mark.it( + "Instantiates with the initial value of the '_provisioning_payload' attribute set to None" + ) + def test_payload(self, client_class, provisioning_pipeline): + client = client_class(provisioning_pipeline) + + assert client._provisioning_payload is None + + +class SharedProvisioningClientCreateMethodUserOptionTests(object): + @pytest.mark.it( + "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" + ) + def test_websockets_option( + self, mocker, client_create_method, create_method_args, mock_pipeline_init + ): + client_create_method(*create_method_args, websockets=True) + + # Get configuration object + assert mock_pipeline_init.call_count == 1 + config = mock_pipeline_init.call_args[0][0] + assert isinstance(config, ProvisioningPipelineConfig) + + assert config.websockets + + # TODO: Show that input in the wrong format is formatted to the correct one. This test exists + # in the ProvisioningPipelineConfig object already, but we do not currently show that this is felt + # from the API level. 
+ @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") + def test_cipher_option(self, client_create_method, create_method_args, mock_pipeline_init): + + cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" + client_create_method(*create_method_args, cipher=cipher) + + # Get configuration object + assert mock_pipeline_init.call_count == 1 + config = mock_pipeline_init.call_args[0][0] + assert isinstance(config, ProvisioningPipelineConfig) + + assert config.cipher == cipher + + @pytest.mark.it("Sets the 'proxy_options' user option parameter on the PipelineConfig") + def test_proxy_options(self, client_create_method, create_method_args, mock_pipeline_init): + proxy_options = ProxyOptions(proxy_type=socks.HTTP, proxy_addr="127.0.0.1", proxy_port=8888) + client_create_method(*create_method_args, proxy_options=proxy_options) + + # Get configuration object + assert mock_pipeline_init.call_count == 1 + config = mock_pipeline_init.call_args[0][0] + assert isinstance(config, ProvisioningPipelineConfig) + + assert config.proxy_options is proxy_options + + @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") + def test_invalid_option( + self, mocker, client_create_method, create_method_args, mock_pipeline_init + ): + with pytest.raises(TypeError): + client_create_method(*create_method_args, invalid_option="some_value") + + @pytest.mark.it("Sets default user options if none are provided") + def test_default_options( + self, mocker, client_create_method, create_method_args, mock_pipeline_init + ): + client_create_method(*create_method_args) + + # Pipeline uses a ProvisioningPipelineConfig + assert mock_pipeline_init.call_count == 1 + config = mock_pipeline_init.call_args[0][0] + assert isinstance(config, ProvisioningPipelineConfig) + + # ProvisioningPipelineConfig has default options set that were not user-specified + assert config.websockets is False + assert config.cipher == "" + assert config.proxy_options is None + + +@pytest.mark.usefixtures("mock_pipeline_init") +class SharedProvisioningClientCreateFromSymmetricKeyTests( + SharedProvisioningClientCreateMethodUserOptionTests +): + @pytest.fixture + def client_create_method(self, client_class): + return client_class.create_from_symmetric_key + + @pytest.fixture + def create_method_args(self): + return [fake_provisioning_host, fake_registration_id, fake_id_scope, fake_symmetric_key] + + @pytest.mark.it( + "Creates a SasToken that uses a SymmetricKeySigningMechanism, from the values provided in paramaters" + ) + def test_sastoken(self, mocker, client_class): + sksm_mock = mocker.patch.object(auth, "SymmetricKeySigningMechanism") + sastoken_mock = mocker.patch.object(st, "SasToken") + expected_uri = "{id_scope}/registrations/{registration_id}".format( + id_scope=fake_id_scope, registration_id=fake_registration_id + ) + + client_class.create_from_symmetric_key( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + + # SymmetricKeySigningMechanism created using the provided symmetric key + assert sksm_mock.call_count == 1 + assert sksm_mock.call_args == mocker.call(key=fake_symmetric_key) + + # SasToken created with the SymmetricKeySigningMechanism and the expected URI + assert sastoken_mock.call_count == 1 + assert sastoken_mock.call_args == mocker.call(expected_uri, sksm_mock.return_value) + + @pytest.mark.it( + "Creates an MQTT pipeline with a 
ProvisioningPipelineConfig object containing the SasToken and values provided in the parameters" + ) + def test_pipeline_config(self, mocker, client_class, mock_pipeline_init): + sastoken_mock = mocker.patch.object(st, "SasToken") + + client_class.create_from_symmetric_key( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + + # Verify pipeline was created with a ProvisioningPipelineConfig + assert mock_pipeline_init.call_count == 1 + assert isinstance(mock_pipeline_init.call_args[0][0], ProvisioningPipelineConfig) + + # Verify the ProvisioningPipelineConfig is constructed as expected + config = mock_pipeline_init.call_args[0][0] + assert config.hostname == fake_provisioning_host + assert config.gateway_hostname is None + assert config.registration_id == fake_registration_id + assert config.id_scope == fake_id_scope + assert config.sastoken is sastoken_mock.return_value + + @pytest.mark.it( + "Returns an instance of a ProvisioningDeviceClient using the created MQTT pipeline" + ) + def test_client_returned(self, mocker, client_class, mock_pipeline_init): + client = client_class.create_from_symmetric_key( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + assert isinstance(client, client_class) + assert client._pipeline is mock_pipeline_init.return_value + + @pytest.mark.it("Raises ValueError if a SasToken creation results in failure") + def test_sastoken_failure(self, mocker, client_class): + sastoken_mock = mocker.patch.object(st, "SasToken") + token_err = st.SasTokenError("Some SasToken failure") + sastoken_mock.side_effect = token_err + + with pytest.raises(ValueError) as e_info: + client_class.create_from_symmetric_key( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + symmetric_key=fake_symmetric_key, + ) + assert e_info.value.__cause__ is token_err + + +@pytest.mark.usefixtures("mock_pipeline_init") +class SharedProvisioningClientCreateFromX509CertificateTests( + SharedProvisioningClientCreateMethodUserOptionTests +): + @pytest.fixture + def client_create_method(self, client_class): + return client_class.create_from_x509_certificate + + @pytest.fixture + def create_method_args(self, x509): + return [fake_provisioning_host, fake_registration_id, fake_id_scope, x509] + + @pytest.mark.it( + "Creates an MQTT pipeline with a ProvisioningPipelineConfig object containing the X509 and other values provided in the parameters" + ) + def test_pipeline_config(self, mocker, client_class, x509, mock_pipeline_init): + client_class.create_from_x509_certificate( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + + # Verify pipeline created with a ProvisioningPipelineConfig + assert mock_pipeline_init.call_count == 1 + assert isinstance(mock_pipeline_init.call_args[0][0], ProvisioningPipelineConfig) + + # Verify the ProvisioningPipelineConfig is constructed as expected + config = mock_pipeline_init.call_args[0][0] + assert config.hostname == fake_provisioning_host + assert config.gateway_hostname is None + assert config.registration_id == fake_registration_id + assert config.id_scope == fake_id_scope + assert config.x509 is x509 + + @pytest.mark.it( + "Returns an instance of a ProvisioningDeviceClient using the created MQTT pipeline" + ) + def test_client_returned(self, 
mocker, client_class, x509, mock_pipeline_init): + client = client_class.create_from_x509_certificate( + provisioning_host=fake_provisioning_host, + registration_id=fake_registration_id, + id_scope=fake_id_scope, + x509=x509, + ) + + assert isinstance(client, client_class) + assert client._pipeline is mock_pipeline_init.return_value diff --git a/azure-iot-device/tests/provisioning/test_provisioning_device_client.py b/azure-iot-device/tests/provisioning/test_provisioning_device_client.py deleted file mode 100644 index d1968f59a..000000000 --- a/azure-iot-device/tests/provisioning/test_provisioning_device_client.py +++ /dev/null @@ -1,27 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- - -import pytest -import logging -from azure.iot.device.provisioning.abstract_provisioning_device_client import ( - AbstractProvisioningDeviceClient, -) - -logging.basicConfig(level=logging.DEBUG) - - -class Wizard(object): - def __init__(self, first_name, last_name, dict_of_stuff): - self.first_name = first_name - self.last_name = last_name - self.props = dict_of_stuff - - -@pytest.mark.it("Init of abstract client raises exception") -def test_raises_exception_on_init_of_abstract_client(mocker): - fake_pipeline = mocker.MagicMock() - with pytest.raises(TypeError): - AbstractProvisioningDeviceClient(fake_pipeline) diff --git a/azure-iot-device/tests/provisioning/test_sync_provisioning_device_client.py b/azure-iot-device/tests/provisioning/test_sync_provisioning_device_client.py index 2ca5255cb..9cdc1e6f4 100644 --- a/azure-iot-device/tests/provisioning/test_sync_provisioning_device_client.py +++ b/azure-iot-device/tests/provisioning/test_sync_provisioning_device_client.py @@ -5,301 +5,55 @@ # -------------------------------------------------------------------------- import pytest import logging -from azure.iot.device.common.models.x509 import X509 from azure.iot.device.provisioning.provisioning_device_client import ProvisioningDeviceClient -from azure.iot.device.provisioning.models.registration_result import ( - RegistrationResult, - RegistrationState, -) from azure.iot.device.provisioning.pipeline import exceptions as pipeline_exceptions -from azure.iot.device.provisioning import security, pipeline +from azure.iot.device.provisioning import pipeline import threading from azure.iot.device import exceptions as client_exceptions +from .shared_client_tests import ( + SharedProvisioningClientInstantiationTests, + SharedProvisioningClientCreateFromSymmetricKeyTests, + SharedProvisioningClientCreateFromX509CertificateTests, +) -logging.basicConfig(level=logging.DEBUG) - -fake_symmetric_key = "Zm9vYmFy" -fake_registration_id = "MyPensieve" -fake_id_scope = "Enchanted0000Ceiling7898" -fake_provisioning_host = "hogwarts.com" -fake_x509_cert_file_value = "fantastic_beasts" -fake_x509_cert_key_file = "where_to_find_them" -fake_pass_phrase = "alohomora" -fake_status = "flying" -fake_sub_status = "FlyingOnHippogriff" -fake_operation_id = "quidditch_world_cup" -fake_request_id = "request_1234" -fake_device_id = "MyNimbus2000" -fake_assigned_hub = "Dumbledore'sArmy" - - -@pytest.fixture -def registration_result(): - registration_state = RegistrationState(fake_device_id, fake_assigned_hub, fake_sub_status) - return RegistrationResult(fake_operation_id, 
fake_status, registration_state) - - -@pytest.fixture -def x509(): - return X509(fake_x509_cert_file_value, fake_x509_cert_key_file, fake_pass_phrase) - - -@pytest.fixture(autouse=True) -def provisioning_pipeline(mocker): - return mocker.MagicMock(wraps=FakeProvisioningPipeline()) - - -class FakeProvisioningPipeline: - def __init__(self): - self.responses_enabled = {} - - def connect(self, callback): - callback() - - def disconnect(self, callback): - callback() - - def enable_responses(self, callback): - callback() - - def register(self, payload, callback): - callback(result={}) - - -# automatically mock the pipeline for all tests in this file -@pytest.fixture(autouse=True) -def mock_pipeline_init(mocker): - return mocker.patch("azure.iot.device.provisioning.pipeline.ProvisioningPipeline") - - -class SharedClientCreateMethodUserOptionTests(object): - @pytest.mark.it( - "Sets the 'websockets' user option parameter on the PipelineConfig, if provided" - ) - def test_websockets_option( - self, mocker, client_create_method, create_method_args, mock_pipeline_init - ): - client_create_method(*create_method_args, websockets=True) - - # Get configuration object - assert mock_pipeline_init.call_count == 1 - config = mock_pipeline_init.call_args[0][1] - - assert config.websockets - - # TODO: Show that input in the wrong format is formatted to the correct one. This test exists - # in the ProvisioningPipelineConfig object already, but we do not currently show that this is felt - # from the API level. - @pytest.mark.it("Sets the 'cipher' user option parameter on the PipelineConfig, if provided") - def test_cipher_option( - self, mocker, client_create_method, create_method_args, mock_pipeline_init - ): - - cipher = "DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA:ECDHE-ECDSA-AES128-GCM-SHA256" - client_create_method(*create_method_args, cipher=cipher) - - # Get configuration object - assert mock_pipeline_init.call_count == 1 - config = mock_pipeline_init.call_args[0][1] - - assert config.cipher == cipher - - @pytest.mark.it("Raises a TypeError if an invalid user option parameter is provided") - def test_invalid_option( - self, mocker, client_create_method, create_method_args, mock_pipeline_init - ): - with pytest.raises(TypeError): - client_create_method(*create_method_args, invalid_option="some_value") - - @pytest.mark.it("Sets default user options if none are provided") - def test_default_options( - self, mocker, client_create_method, create_method_args, mock_pipeline_init - ): - mock_config = mocker.patch( - "azure.iot.device.provisioning.pipeline.ProvisioningPipelineConfig" - ) - client_create_method(*create_method_args) - - # Pipeline Config was instantiated with default arguments - assert mock_config.call_count == 1 - expected_kwargs = {} - assert mock_config.call_args == mocker.call(**expected_kwargs) - - # This default config was used for the protocol pipeline - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args[0][1] == mock_config.return_value - - -@pytest.mark.describe("ProvisioningDeviceClient - Instantiation") -class TestClientInstantiation(object): - @pytest.mark.it( - "Stores the ProvisioningPipeline from the 'provisioning_pipeline' parameter in the '_provisioning_pipeline' attribute" - ) - def test_sets_provisioning_pipeline(self, provisioning_pipeline): - client = ProvisioningDeviceClient(provisioning_pipeline) - - assert client._provisioning_pipeline is provisioning_pipeline - - @pytest.mark.it( - "Instantiates with the initial value of the '_provisioning_payload' 
attribute set to None" - ) - def test_payload(self, provisioning_pipeline): - client = ProvisioningDeviceClient(provisioning_pipeline) - - assert client._provisioning_payload is None - - -@pytest.mark.describe("ProvisioningDeviceClient - .create_from_symmetric_key()") -class TestClientCreateFromSymmetricKey(SharedClientCreateMethodUserOptionTests): - @pytest.fixture - def client_create_method(self): - return ProvisioningDeviceClient.create_from_symmetric_key - - @pytest.fixture - def create_method_args(self): - return [fake_provisioning_host, fake_registration_id, fake_id_scope, fake_symmetric_key] - - @pytest.mark.it("Creates a SymmetricKeySecurityClient using the given parameters") - def test_security_client(self, mocker): - spy_sec_client = mocker.spy(security, "SymmetricKeySecurityClient") - - ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - - assert spy_sec_client.call_count == 1 - assert spy_sec_client.call_args == mocker.call( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - - @pytest.mark.it( - "Uses the SymmetricKeySecurityClient object and the ProvisioningPipelineConfig object to create a ProvisioningPipeline" - ) - def test_pipeline(self, mocker, mock_pipeline_init): - # Note that the details of how the pipeline config is set up are covered in the - # SharedClientCreateMethodUserOptionTests - mock_pipeline_config = mocker.patch.object( - pipeline, "ProvisioningPipelineConfig" - ).return_value - mock_sec_client = mocker.patch.object(security, "SymmetricKeySecurityClient").return_value - - ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_sec_client, mock_pipeline_config) - - @pytest.mark.it("Uses the ProvisioningPipeline to instantiate the client") - def test_client_creation(self, mocker, mock_pipeline_init): - spy_client_init = mocker.spy(ProvisioningDeviceClient, "__init__") - - ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - assert spy_client_init.call_count == 1 - assert spy_client_init.call_args == mocker.call(mocker.ANY, mock_pipeline_init.return_value) +logging.basicConfig(level=logging.DEBUG) - @pytest.mark.it("Returns the instantiated client") - def test_returns_client(self, mocker): - client = ProvisioningDeviceClient.create_from_symmetric_key( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - symmetric_key=fake_symmetric_key, - ) - assert isinstance(client, ProvisioningDeviceClient) +class ProvisioningClientTestsConfig(object): + """Defines fixtures for synchronous ProvisioningDeviceClient tests""" -@pytest.mark.describe("ProvisioningDeviceClient - .create_from_x509_certificate()") -class TestClientCreateFromX509Certificate(SharedClientCreateMethodUserOptionTests): @pytest.fixture - def client_create_method(self): - return ProvisioningDeviceClient.create_from_x509_certificate + def client_class(self): + return ProvisioningDeviceClient @pytest.fixture 
- def create_method_args(self, x509): - return [fake_provisioning_host, fake_registration_id, fake_id_scope, x509] - - @pytest.mark.it("Creates an X509SecurityClient using the given parameters") - def test_security_client(self, mocker, x509): - spy_sec_client = mocker.spy(security, "X509SecurityClient") - - ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) + def client(self, provisioning_pipeline): + return ProvisioningDeviceClient(provisioning_pipeline) - assert spy_sec_client.call_count == 1 - assert spy_sec_client.call_args == mocker.call( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) - @pytest.mark.it( - "Uses the X509SecurityClient object and the ProvisioningPipelineConfig object to create a ProvisioningPipeline" - ) - def test_pipeline(self, mocker, mock_pipeline_init, x509): - # Note that the details of how the pipeline config is set up are covered in the - # SharedClientCreateMethodUserOptionTests - mock_pipeline_config = mocker.patch.object( - pipeline, "ProvisioningPipelineConfig" - ).return_value - mock_sec_client = mocker.patch.object(security, "X509SecurityClient").return_value - - ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) - - assert mock_pipeline_init.call_count == 1 - assert mock_pipeline_init.call_args == mocker.call(mock_sec_client, mock_pipeline_config) +@pytest.mark.describe("ProvisioningDeviceClient (Sync) - Instantiation") +class TestProvisioningClientInstantiation( + ProvisioningClientTestsConfig, SharedProvisioningClientInstantiationTests +): + pass - @pytest.mark.it("Uses the ProvisioningPipeline to instantiate the client") - def test_client_creation(self, mocker, mock_pipeline_init, x509): - spy_client_init = mocker.spy(ProvisioningDeviceClient, "__init__") - ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) +@pytest.mark.describe("ProvisioningDeviceClient (Sync) - .create_from_symmetric_key()") +class TestProvisioningClientCreateFromSymmetricKey( + ProvisioningClientTestsConfig, SharedProvisioningClientCreateFromSymmetricKeyTests +): + pass - assert spy_client_init.call_count == 1 - assert spy_client_init.call_args == mocker.call(mocker.ANY, mock_pipeline_init.return_value) - @pytest.mark.it("Returns the instantiated client") - def test_returns_client(self, mocker, x509): - client = ProvisioningDeviceClient.create_from_x509_certificate( - provisioning_host=fake_provisioning_host, - registration_id=fake_registration_id, - id_scope=fake_id_scope, - x509=x509, - ) - assert isinstance(client, ProvisioningDeviceClient) +@pytest.mark.describe("ProvisioningDeviceClient (Sync) - .create_from_x509_certificate()") +class TestProvisioningClientCreateFromX509Certificate( + ProvisioningClientTestsConfig, SharedProvisioningClientCreateFromX509CertificateTests +): + pass -@pytest.mark.describe("ProvisioningDeviceClient - .register()") +@pytest.mark.describe("ProvisioningDeviceClient (Sync) - .register()") class TestClientRegister(object): @pytest.mark.it("Implicitly enables responses from provisioning service if not already enabled") def test_enables_provisioning_only_if_not_already_enabled( @@ -438,7 
+192,7 @@ def register_complete_failure_callback(payload, callback): assert provisioning_pipeline.register.call_count == 1 -@pytest.mark.describe("ProvisioningDeviceClient - .set_provisioning_payload()") +@pytest.mark.describe("ProvisioningDeviceClient (Sync) - .set_provisioning_payload()") class TestClientProvisioningPayload(object): @pytest.mark.it("Sets the payload on the provisioning payload attribute") @pytest.mark.parametrize( diff --git a/azure-iot-device/tests/test_product_info.py b/azure-iot-device/tests/test_product_info.py deleted file mode 100644 index e31bfd02a..000000000 --- a/azure-iot-device/tests/test_product_info.py +++ /dev/null @@ -1,69 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import pytest -from azure.iot.device.product_info import ProductInfo -import platform -from azure.iot.device.constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER - - -check_agent_format = ( - "{identifier}/{version}({python_runtime};{os_type} {os_release};{architecture})" -) - - -@pytest.mark.describe("ProductInfo") -class TestProductInfo(object): - @pytest.mark.it( - "Contains python version, operating system and architecture of the system in the iothub agent string" - ) - def test_get_iothub_user_agent(self): - user_agent = ProductInfo.get_iothub_user_agent() - - assert IOTHUB_IDENTIFIER in user_agent - assert VERSION in user_agent - assert platform.python_version() in user_agent - assert platform.system() in user_agent - assert platform.version() in user_agent - assert platform.machine() in user_agent - - @pytest.mark.it("Checks if the format of the agent string is as expected") - def test_checks_format_iothub_agent(self): - expected_part_agent = check_agent_format.format( - identifier=IOTHUB_IDENTIFIER, - version=VERSION, - python_runtime=platform.python_version(), - os_type=platform.system(), - os_release=platform.version(), - architecture=platform.machine(), - ) - user_agent = ProductInfo.get_iothub_user_agent() - assert expected_part_agent in user_agent - - @pytest.mark.it( - "Contains python version, operating system and architecture of the system in the provisioning agent string" - ) - def test_get_provisioning_user_agent(self): - user_agent = ProductInfo.get_provisioning_user_agent() - - assert PROVISIONING_IDENTIFIER in user_agent - assert VERSION in user_agent - assert platform.python_version() in user_agent - assert platform.system() in user_agent - assert platform.version() in user_agent - assert platform.machine() in user_agent - - @pytest.mark.it("Checks if the format of the agent string is as expected") - def test_checks_format_provisioning_agent(self): - expected_part_agent = check_agent_format.format( - identifier=PROVISIONING_IDENTIFIER, - version=VERSION, - python_runtime=platform.python_version(), - os_type=platform.system(), - os_release=platform.version(), - architecture=platform.machine(), - ) - user_agent = ProductInfo.get_provisioning_user_agent() - assert expected_part_agent in user_agent diff --git a/azure-iot-device/tests/test_user_agent.py b/azure-iot-device/tests/test_user_agent.py new file mode 100644 index 000000000..e005a74d9 --- /dev/null +++ b/azure-iot-device/tests/test_user_agent.py @@ -0,0 +1,65 @@ +# 
------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import pytest +from azure.iot.device import user_agent +import platform +from azure.iot.device.constant import VERSION, IOTHUB_IDENTIFIER, PROVISIONING_IDENTIFIER + + +check_agent_format = ( + "{identifier}/{version}({python_runtime};{os_type} {os_release};{architecture})" +) + + +@pytest.mark.describe(".get_iothub_user_agent()") +class TestGetIothubUserAgent(object): + @pytest.mark.it( + "Returns a user agent string formatted for IoTHub, containing python version, operating system and architecture of the system" + ) + def test_get_iothub_user_agent(self): + user_agent_str = user_agent.get_iothub_user_agent() + + assert IOTHUB_IDENTIFIER in user_agent_str + assert VERSION in user_agent_str + assert platform.python_version() in user_agent_str + assert platform.system() in user_agent_str + assert platform.version() in user_agent_str + assert platform.machine() in user_agent_str + expected_part_agent = check_agent_format.format( + identifier=IOTHUB_IDENTIFIER, + version=VERSION, + python_runtime=platform.python_version(), + os_type=platform.system(), + os_release=platform.version(), + architecture=platform.machine(), + ) + assert expected_part_agent == user_agent_str + + +@pytest.mark.describe(".get_provisioning_user_agent()") +class TestGetProvisioningUserAgent(object): + @pytest.mark.it( + "Returns a user agent string formatted for the Provisioning Service, containing python version, operating system and architecture of the system" + ) + def test_get_provisioning_user_agent_str(self): + user_agent_str = user_agent.get_provisioning_user_agent() + + assert PROVISIONING_IDENTIFIER in user_agent_str + assert VERSION in user_agent_str + assert platform.python_version() in user_agent_str + assert platform.system() in user_agent_str + assert platform.version() in user_agent_str + assert platform.machine() in user_agent_str + + expected_part_agent = check_agent_format.format( + identifier=PROVISIONING_IDENTIFIER, + version=VERSION, + python_runtime=platform.python_version(), + os_type=platform.system(), + os_release=platform.version(), + architecture=platform.machine(), + ) + assert expected_part_agent == user_agent_str diff --git a/azure_provisioning_e2e/service_helper.py b/azure_provisioning_e2e/service_helper.py index b24602f30..32338c07c 100644 --- a/azure_provisioning_e2e/service_helper.py +++ b/azure_provisioning_e2e/service_helper.py @@ -9,8 +9,9 @@ ) from msrest.exceptions import HttpOperationError -from azure.iot.device.common.connection_string import ConnectionString -from azure.iot.device.common.sastoken import SasToken +from azure.iot.device.common.auth.connection_string import ConnectionString +from azure.iot.device.common.auth.sastoken import SasToken +from azure.iot.device.common.auth.signing_mechanism import SymmetricKeySigningMechanism import uuid import time import random @@ -26,9 +27,10 @@ def connection_string_to_sas_token(conn_str): signature that can be used to connect to the given hub """ conn_str_obj = ConnectionString(conn_str) + signing_mechanism = SymmetricKeySigningMechanism(conn_str_obj.get("SharedAccessKey")) sas_token = SasToken( uri=conn_str_obj.get("HostName"), - key=conn_str_obj.get("SharedAccessKey"), + signing_mechanism=signing_mechanism, 
key_name=conn_str_obj.get("SharedAccessKeyName"), ) diff --git a/credscan_suppression.json b/credscan_suppression.json index 3c738de48..810a84f4d 100644 --- a/credscan_suppression.json +++ b/credscan_suppression.json @@ -8,6 +8,14 @@ { "file": "\\azure_provisioning_e2e\\tests\\test_sync_certificate_enrollments.py", "_justification": "Test containing fake passwords and keys" + }, + { + "file": "\\azure-iot-device\\tests\\common\\auth\\test_signing_mechanism.py", + "_justification:": "Test containing fake keys" + }, + { + "file": "\\azure-iot-device\\tests\\iothub\\client_fixtures.py", + "_justification": "Test containing fake keys" } ]