Skip to content

Commit

Permalink
feat(databases): test connection api (#10723)
Browse files Browse the repository at this point in the history
* test connection api on databases

* update test connection tests

* update database api test and open api description

* moved test connection to commands

* update error message

* fix isort

* fix mypy

* fix black

* fix mypy pre commit
  • Loading branch information
Lily Kuang authored Sep 9, 2020
1 parent 9a59bdd commit 8a3ac70
Show file tree
Hide file tree
Showing 10 changed files with 316 additions and 23 deletions.
2 changes: 1 addition & 1 deletion superset/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def init_views(self) -> None:
AlertLogModelView,
AlertModelView,
AlertObservationModelView,
ValidatorInlineView,
SQLObserverInlineView,
ValidatorInlineView,
)
from superset.views.annotations import (
AnnotationLayerModelView,
Expand Down
99 changes: 95 additions & 4 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@
from flask import g, request, Response
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import gettext as _
from marshmallow import ValidationError
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import (
NoSuchModuleError,
NoSuchTableError,
OperationalError,
SQLAlchemyError,
)

from superset import event_logger
from superset.constants import RouteMethod
Expand All @@ -33,8 +40,10 @@
DatabaseDeleteFailedError,
DatabaseInvalidError,
DatabaseNotFoundError,
DatabaseSecurityUnsafeError,
DatabaseUpdateFailedError,
)
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.commands.update import UpdateDatabaseCommand
from superset.databases.dao import DatabaseDAO
from superset.databases.decorators import check_datasource_access
Expand All @@ -44,6 +53,7 @@
DatabasePostSchema,
DatabasePutSchema,
DatabaseRelatedObjectsResponse,
DatabaseTestConnectionSchema,
SchemasResponseSchema,
SelectStarResponseSchema,
TableMetadataResponseSchema,
Expand All @@ -65,6 +75,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"table_metadata",
"select_star",
"schemas",
"test_connection",
"related_objects",
}
class_permission_name = "DatabaseView"
Expand Down Expand Up @@ -343,7 +354,7 @@ def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ
@rison(database_schemas_query_schema)
@statsd_metrics
def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse:
""" Get all schemas from a database
"""Get all schemas from a database
---
get:
description: Get all schemas from a database
Expand Down Expand Up @@ -400,7 +411,7 @@ def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse:
def table_metadata(
self, database: Database, table_name: str, schema_name: str
) -> FlaskResponse:
""" Table schema info
"""Table schema info
---
get:
description: Get database table metadata
Expand Down Expand Up @@ -457,7 +468,7 @@ def table_metadata(
def select_star(
self, database: Database, table_name: str, schema_name: Optional[str] = None
) -> FlaskResponse:
""" Table schema info
"""Table schema info
---
get:
description: Get database select star for table
Expand Down Expand Up @@ -506,6 +517,86 @@ def select_star(
self.incr_stats("success", self.select_star.__name__)
return self.response(200, result=result)

@expose("/test_connection", methods=["POST"])
@protect()
@safe
@event_logger.log_this
@statsd_metrics
def test_connection( # pylint: disable=too-many-return-statements
self,
) -> FlaskResponse:
"""Tests a database connection
---
post:
description: >-
Tests a database connection
requestBody:
description: Database schema
required: true
content:
application/json:
schema:
type: object
properties:
encrypted_extra:
type: object
extras:
type: object
name:
type: string
server_cert:
type: string
responses:
200:
description: Database Test Connection
content:
application/json:
schema:
type: object
properties:
message:
type: string
400:
$ref: '#/components/responses/400'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
if not request.is_json:
return self.response_400(message="Request is not JSON")
try:
item = DatabaseTestConnectionSchema().load(request.json)
# This validates custom Schema with custom validations
except ValidationError as error:
return self.response_400(message=error.messages)
try:
TestConnectionDatabaseCommand(g.user, item).run()
return self.response(200, message="OK")
except (NoSuchModuleError, ModuleNotFoundError):
logger.info("Invalid driver")
driver_name = make_url(item.get("sqlalchemy_uri")).drivername
return self.response(
400,
message=_(f"Could not load database driver: {driver_name}"),
driver_name=driver_name,
)
except DatabaseSecurityUnsafeError as ex:
return self.response_422(message=ex)
except OperationalError:
logger.warning("Connection failed")
return self.response(
500,
message=_("Connection failed, please check your connection settings"),
)
except Exception as ex: # pylint: disable=broad-except
logger.error("Unexpected error %s", type(ex).__name__)
return self.response_400(
message=_(
"Unexpected error occurred, please check your logs for details"
)
)

@expose("/<int:pk>/related_objects/", methods=["GET"])
@protect()
@safe
Expand Down
5 changes: 5 additions & 0 deletions superset/databases/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
DeleteFailedError,
UpdateFailedError,
)
from superset.security.analytics_db_safety import DBSecurityException


class DatabaseInvalidError(CommandInvalidError):
Expand Down Expand Up @@ -109,3 +110,7 @@ class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):

class DatabaseDeleteFailedError(DeleteFailedError):
message = _("Database could not be deleted.")


class DatabaseSecurityUnsafeError(DBSecurityException):
message = _("Stopped an unsafe database connection")
67 changes: 67 additions & 0 deletions superset/databases/commands/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from contextlib import closing
from typing import Any, Dict, Optional

import simplejson as json
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import select

from superset.commands.base import BaseCommand
from superset.databases.commands.exceptions import DatabaseSecurityUnsafeError
from superset.databases.dao import DatabaseDAO
from superset.models.core import Database
from superset.security.analytics_db_safety import DBSecurityException

logger = logging.getLogger(__name__)


class TestConnectionDatabaseCommand(BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()
self._model: Optional[Database] = None

def run(self) -> None:
self.validate()
try:
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted

database = DatabaseDAO.build_db_for_connection_test(
server_cert=self._properties.get("server_cert", ""),
extra=json.dumps(self._properties.get("extra", {})),
impersonate_user=self._properties.get("impersonate_user", False),
encrypted_extra=json.dumps(self._properties.get("encrypted_extra", {})),
)
if database is not None:
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
username = self._actor.username if self._actor is not None else None
engine = database.get_sqla_engine(user_name=username)
with closing(engine.connect()) as conn:
conn.scalar(select([1]))
except DBSecurityException as ex:
logger.warning(ex)
raise DatabaseSecurityUnsafeError()

def validate(self) -> None:
database_name = self._properties.get("database_name")
if database_name is not None:
self._model = DatabaseDAO.get_database_by_name(database_name)
21 changes: 20 additions & 1 deletion superset/databases/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from typing import Any, Dict, Optional

from superset.dao.base import BaseDAO
from superset.databases.filters import DatabaseFilter
Expand Down Expand Up @@ -45,6 +45,25 @@ def validate_update_uniqueness(database_id: int, database_name: str) -> bool:
)
return not db.session.query(database_query.exists()).scalar()

@staticmethod
def get_database_by_name(database_name: str) -> Optional[Database]:
return (
db.session.query(Database)
.filter(Database.database_name == database_name)
.one_or_none()
)

@staticmethod
def build_db_for_connection_test(
server_cert: str, extra: str, impersonate_user: bool, encrypted_extra: str
) -> Optional[Database]:
return Database(
server_cert=server_cert,
extra=extra,
impersonate_user=impersonate_user,
encrypted_extra=encrypted_extra,
)

@classmethod
def get_related_objects(cls, database_id: int) -> Dict[str, Any]:
datasets = cls.find_by_id(database_id).tables
Expand Down
23 changes: 21 additions & 2 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
import inspect
import json

from flask import current_app
from flask_babel import lazy_gettext as _
from marshmallow import fields, Schema
from marshmallow.validate import Length, ValidationError
from sqlalchemy import MetaData
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import ArgumentError

from superset import app
from superset.exceptions import CertificateException
from superset.utils.core import markdown, parse_ssl_cert

Expand Down Expand Up @@ -142,7 +142,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
)
]
)
if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] and value:
if current_app.config.get("PREVENT_UNSAFE_DB_CONNECTIONS", True) and value:
if value.startswith("sqlite"):
raise ValidationError(
[
Expand Down Expand Up @@ -291,6 +291,25 @@ class DatabasePutSchema(Schema):
)


class DatabaseTestConnectionSchema(Schema):
database_name = fields.String(
description=database_name_description, allow_none=True, validate=Length(1, 250),
)
impersonate_user = fields.Boolean(description=impersonate_user_description)
extra = fields.String(description=extra_description, validate=extra_validator)
encrypted_extra = fields.String(
description=encrypted_extra_description, validate=encrypted_extra_validator
)
server_cert = fields.String(
description=server_cert_description, validate=server_cert_validator
)
sqlalchemy_uri = fields.String(
description=sqlalchemy_uri_description,
required=True,
validate=[Length(1, 1024), sqlalchemy_uri_validator],
)


class TableMetadataOptionsResponseSchema(Schema):
deferrable = fields.Bool()
initially = fields.Bool()
Expand Down
2 changes: 1 addition & 1 deletion superset/tasks/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@

if TYPE_CHECKING:
# pylint: disable=unused-import
from werkzeug.datastructures import TypeConversionDict
from flask_appbuilder.security.sqla.models import User
from werkzeug.datastructures import TypeConversionDict

# Globals
config = app.config
Expand Down
25 changes: 13 additions & 12 deletions superset/views/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,24 +91,25 @@ class BaseSupersetModelRestApi(ModelRestApi):

csrf_exempt = False
method_permission_name = {
"get_list": "list",
"get": "show",
"bulk_delete": "delete",
"data": "list",
"delete": "delete",
"distinct": "list",
"export": "mulexport",
"get": "show",
"get_list": "list",
"info": "list",
"post": "add",
"put": "edit",
"delete": "delete",
"bulk_delete": "delete",
"info": "list",
"related": "list",
"distinct": "list",
"thumbnail": "list",
"refresh": "edit",
"data": "list",
"viz_types": "list",
"related": "list",
"related_objects": "list",
"table_metadata": "list",
"select_star": "list",
"schemas": "list",
"select_star": "list",
"table_metadata": "list",
"test_connection": "post",
"thumbnail": "list",
"viz_types": "list",
}

order_rel_fields: Dict[str, Tuple[str, str]] = {}
Expand Down
2 changes: 1 addition & 1 deletion superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,7 @@ def testconn( # pylint: disable=too-many-return-statements,no-self-use
logger.warning("Stopped an unsafe database connection")
return json_error_response(_(str(ex)), 400)
except Exception as ex: # pylint: disable=broad-except
logger.error("Unexpected error %s", type(ex).__name__)
logger.warning("Unexpected error %s", type(ex).__name__)
return json_error_response(
_("Unexpected error occurred, please check your logs for details"), 400
)
Expand Down
Loading

0 comments on commit 8a3ac70

Please # to comment.