diff --git a/aws_gate/cli.py b/aws_gate/cli.py index ac37c2fd..745edfd3 100644 --- a/aws_gate/cli.py +++ b/aws_gate/cli.py @@ -28,6 +28,7 @@ from aws_gate.ssh import ssh from aws_gate.ssh_config import ssh_config from aws_gate.ssh_proxy import ssh_proxy +from aws_gate.port_forward import port_forward from aws_gate.utils import get_default_region logger = logging.getLogger(__name__) @@ -88,6 +89,25 @@ def get_argument_parser(*args, **kwargs): "instance_name", help="Instance we wish to open session to" ) + # 'port-forward' subcommand + port_forward_parser = subparsers.add_parser( + "port-forward", help="Open new session on instance and forward to a port locally or remotely" + ) + port_forward_parser.add_argument("-p", "--profile", help="AWS profile to use") + port_forward_parser.add_argument("-r", "--region", help="AWS region to use") + port_forward_parser.add_argument( + "instance_name", help="Instance we wish to open session to" + ) + port_forward_parser.add_argument( + "target_port", help="Port to forward to", type=int + ) + port_forward_parser.add_argument( + "--target_host", help="Host to forward into", default=None + ) + port_forward_parser.add_argument( + "--local_port", help="Local port to forward to", type=int, default=7000 + ) + # 'ssh' subcommand ssh_parser = subparsers.add_parser( "ssh", help="Open SSH session on instance and connect to it" @@ -284,6 +304,16 @@ def main(args=None, argument_parser=None): region_name=region, profile_name=profile, ) + elif args.subcommand == "port-forward": + port_forward( + config=config, + instance_name=args.instance_name, + target_host=args.target_host, + region_name=region, + profile_name=profile, + target_port=args.target_port, + local_port=args.local_port, + ) elif args.subcommand == "ssh": ssh( config=config, diff --git a/aws_gate/port_forward.py b/aws_gate/port_forward.py new file mode 100644 index 00000000..1135dd29 --- /dev/null +++ b/aws_gate/port_forward.py @@ -0,0 +1,111 @@ +import logging + +from aws_gate.constants import AWS_DEFAULT_PROFILE, AWS_DEFAULT_REGION +from aws_gate.decorators import ( + plugin_version, + plugin_required, + valid_aws_profile, + valid_aws_region, +) +from aws_gate.query import query_instance +from aws_gate.session_common import BaseSession +from aws_gate.utils import ( + get_aws_client, + get_aws_resource, + fetch_instance_details_from_config, +) + +logger = logging.getLogger(__name__) + + +class SSMPortForwardSession(BaseSession): + def __init__( + self, + instance_id, + target_port: int, + target_host=None, + region_name=AWS_DEFAULT_REGION, + profile_name=AWS_DEFAULT_PROFILE, + local_port: int = 7000, + ssm=None, + ): + self._instance_id = instance_id + self._region_name = region_name + self._profile_name = profile_name if profile_name is not None else "" + self._ssm = ssm + self._target_host = target_host + self._target_port = target_port + self._local_port = local_port + + forward_parameters = { + "portNumber": [str(self._target_port)], + "localPortNumber": [str(self._local_port)], + } + + # local forward or remote forward + if self._target_host is None: + document_name = "AWS-StartPortForwardingSession" + else: + document_name = "AWS-StartPortForwardingSessionToRemoteHost" + forward_parameters.update({"host": [self._target_host]}) + + start_session_kwargs = dict( + Target=self._instance_id, + DocumentName=document_name, + Parameters=forward_parameters, + ) + + self._session_parameters = start_session_kwargs + + +@plugin_required +@plugin_version("1.1.23.0") +@valid_aws_profile +@valid_aws_region +def port_forward( + config, + instance_name, + target_host, + target_port, + local_port=7000, + profile_name=AWS_DEFAULT_PROFILE, + region_name=AWS_DEFAULT_REGION, +): + instance, profile, region = fetch_instance_details_from_config( + config, instance_name, profile_name, region_name + ) + + ssm = get_aws_client("ssm", region_name=region, profile_name=profile) + ec2 = get_aws_resource("ec2", region_name=region, profile_name=profile) + + instance_id = query_instance(name=instance, ec2=ec2) + if instance_id is None: + raise ValueError(f"No instance could be found for name: {instance}") + + if target_host is None: + logger.info( + "Opening SSM Port Forwarding Session listening on %s in instance %s (%s) via profile %s to %s:%s", + target_port, + instance_id, + region, + profile, + ) + else: + logger.info( + "Opening SSM Port Forwarding Session to %s:%s via instance %s (%s) via profile %s to %s:%s", + target_host, + target_port, + instance_id, + region, + profile, + ) + with SSMPortForwardSession( + instance_id, + region_name=region, + profile_name=profile, + ssm=ssm, + target_host=target_host, + target_port=target_port, + local_port=local_port, + ) as sess: + sess.open() diff --git a/tests/unit/test_port_forward.py b/tests/unit/test_port_forward.py new file mode 100644 index 00000000..1738dd47 --- /dev/null +++ b/tests/unit/test_port_forward.py @@ -0,0 +1,159 @@ +import pytest + +from aws_gate.port_forward import port_forward, SSMPortForwardSession + + +def test_create_ssm_forward_session(ssm_mock, instance_id): + sess = SSMPortForwardSession( + instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234 + ) + sess.create() + + assert ssm_mock.start_session.called + + +def test_terminate_ssm_forward_session(ssm_mock, instance_id): + sess = SSMPortForwardSession( + instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234 + ) + + sess.create() + sess.terminate() + + assert ssm_mock.terminate_session.called + + +@pytest.mark.parametrize( + "target_host", + [ + None, + "my-fun-host", + ], + ids=["Target is None, Local Forward", "Target is not None, Remote Forward"], +) +def test_open_ssm_forward_session(mocker, instance_id, ssm_mock, target_host): + m = mocker.patch("aws_gate.session_common.execute_plugin", return_value="output") + + sess = SSMPortForwardSession( + instance_id=instance_id, ssm=ssm_mock, target_host=target_host, target_port=1234 + ) + sess.open() + + if target_host: + expected_doc_name = "AWS-StartPortForwardingSessionToRemoteHost" + else: + expected_doc_name = "AWS-StartPortForwardingSession" + + assert sess._session_parameters.get("DocumentName") == expected_doc_name + assert m.called + + +def test_ssm_forward_session_context_manager(ssm_mock, instance_id): + with SSMPortForwardSession( + instance_id=instance_id, ssm=ssm_mock, target_host="localhost", target_port=1234 + ): + pass + + assert ssm_mock.start_session.called + assert ssm_mock.terminate_session.called + + +def test_port_forward(mocker, instance_id, config): + mocker.patch("aws_gate.port_forward.get_aws_client") + mocker.patch("aws_gate.port_forward.get_aws_resource") + mocker.patch("aws_gate.port_forward.query_instance", return_value=instance_id) + port_forward_mock = mocker.patch( + "aws_gate.port_forward.SSMPortForwardSession", return_value=mocker.MagicMock() + ) + mocker.patch("aws_gate.decorators.is_existing_region", return_value=True) + mocker.patch("aws_gate.decorators._plugin_exists", return_value=True) + mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0") + + port_forward( + config=config, + instance_name="instance_name", + target_host="target_host", + target_port=22, + profile_name="default", + region_name="eu-west-1", + ) + + assert port_forward_mock.called + + +def test_port_forward_exception_invalid_profile(mocker, instance_id, config): + mocker.patch("aws_gate.port_forward.get_aws_client") + mocker.patch("aws_gate.port_forward.get_aws_resource") + mocker.patch("aws_gate.port_forward.query_instance", return_value=instance_id) + mocker.patch("aws_gate.decorators.is_existing_region", return_value=True) + mocker.patch("aws_gate.decorators._plugin_exists", return_value=True) + mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0") + + with pytest.raises(ValueError): + port_forward( + config=config, + instance_name="instance_name", + target_host="target_host", + target_port=22, + profile_name="invalid-default", + region_name="eu-west-1", + ) + + +def test_port_forward_exception_invalid_region(mocker, instance_id, config): + mocker.patch("aws_gate.port_forward.get_aws_client") + mocker.patch("aws_gate.port_forward.get_aws_resource") + mocker.patch("aws_gate.port_forward.query_instance", return_value=instance_id) + mocker.patch("aws_gate.decorators.is_existing_profile", return_value=True) + mocker.patch("aws_gate.decorators._plugin_exists", return_value=True) + mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0") + mocker.patch( + "aws_gate.port_forward.SSMPortForwardSession", return_value=mocker.MagicMock() + ) + with pytest.raises(ValueError): + port_forward( + config=config, + region_name="not-a-region", + instance_name="instance_name", + target_port=22, + profile_name="default", + target_host="target_host", + ) + + +def test_port_forward_exception_unknown_instance_id(mocker, instance_id, config): + mocker.patch("aws_gate.port_forward.get_aws_client") + mocker.patch("aws_gate.port_forward.get_aws_resource") + mocker.patch("aws_gate.port_forward.query_instance", return_value=None) + mocker.patch("aws_gate.decorators.is_existing_profile", return_value=True) + mocker.patch("aws_gate.decorators.is_existing_region", return_value=True) + mocker.patch("aws_gate.decorators._plugin_exists", return_value=True) + mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0") + with pytest.raises(ValueError): + port_forward( + config=config, + region_name="ap-southeast-2", + instance_name=instance_id, + target_port=22, + profile_name="default", + target_host="target_host", + ) + + +def test_port_forward_exception_without_config(mocker, instance_id, empty_config): + mocker.patch("aws_gate.port_forward.get_aws_client") + mocker.patch("aws_gate.port_forward.get_aws_resource") + mocker.patch("aws_gate.port_forward.query_instance", return_value=None) + mocker.patch("aws_gate.decorators.is_existing_profile", return_value=True) + mocker.patch("aws_gate.decorators.is_existing_region", return_value=True) + mocker.patch("aws_gate.decorators._plugin_exists", return_value=True) + mocker.patch("aws_gate.decorators.execute_plugin", return_value="1.1.23.0") + with pytest.raises(ValueError): + port_forward( + config=empty_config, + region_name="ap-southeast-2", + instance_name=instance_id, + target_port=22, + profile_name="default", + target_host="target_host", + )