update with enums
hughhhh authored May 11, 2022
Parent: 1dc9c74 · Commit: bcfc683
Showing 2 changed files with 131 additions and 71 deletions.
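The sources mapping in the diff below is keyed on the DatasourceType enum imported from superset.utils.core, which is not part of this commit. For orientation only, here is a hypothetical sketch of the shape that enum is assumed to have; the member names come from the diff, while the string values are guesses inferred from the old string keys and may not match the real definition.

from enum import Enum


class DatasourceType(str, Enum):
    # Hypothetical reconstruction; values are assumptions, not taken from this diff.
    SQLATABLE = "table"
    QUERY = "query"
    SAVEDQUERY = "saved_query"
    DATASET = "sl_dataset"
    TABLE = "sl_table"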
127 changes: 61 additions & 66 deletions superset/dao/datasource/dao.py
@@ -15,37 +15,38 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from typing import List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union
 
 from flask_babel import _
 from sqlalchemy import or_
 from sqlalchemy.orm import Session, subqueryload
 from sqlalchemy.orm.exc import NoResultFound
 
-from superset.connectors.sqla.models import SqlaTable, Table
+from superset.connectors.sqla.models import SqlaTable
 from superset.dao.base import BaseDAO
 from superset.datasets.commands.exceptions import DatasetNotFoundError
 from superset.datasets.models import Dataset
 from superset.models.core import Database
 from superset.models.sql_lab import Query, SavedQuery
+from superset.tables.models import Table
+from superset.utils.core import DatasourceType
 
-Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
+Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery, Any]
 
 
 class DatasourceDAO(BaseDAO):
 
-    sources = {
-        # using table -> SqlaTable for backward compatibility at the moment
-        "table": SqlaTable,
-        "query": Query,
-        "saved_query": SavedQuery,
-        "sl_dataset": Dataset,
-        "sl_table": Table,
+    sources: Dict[DatasourceType, Datasource] = {
+        DatasourceType.SQLATABLE: SqlaTable,
+        DatasourceType.QUERY: Query,
+        DatasourceType.SAVEDQUERY: SavedQuery,
+        DatasourceType.DATASET: Dataset,
+        DatasourceType.TABLE: Table,
     }
 
     @classmethod
     def get_datasource(
-        cls, datasource_type: str, datasource_id: int, session: Session
+        cls, datasource_type: DatasourceType, datasource_id: int, session: Session
     ) -> Datasource:
         if datasource_type not in cls.sources:
             raise DatasetNotFoundError()
@@ -61,91 +62,85 @@ def get_datasource(
 
         return datasource
 
-    def get_all_datasources(self, session: Session) -> List[Datasource]:
-        datasources: List["Datasource"] = []
+    @classmethod
+    def get_all_datasources(cls, session: Session) -> List[Datasource]:
+        datasources: List[Datasource] = []
         for source_class in DatasourceDAO.sources.values():
             qry = session.query(source_class)
-            qry = source_class.default_query(qry)
+            if isinstance(source_class, SqlaTable):
+                qry = source_class.default_query(qry)
             datasources.extend(qry.all())
         return datasources
 
-    def get_datasource_by_id(self, session: Session, datasource_id: int) -> Datasource:
-        """
-        Find a datasource instance based on the unique id.
-        :param session: Session to use
-        :param datasource_id: unique id of datasource
-        :return: Datasource corresponding to the id
-        :raises NoResultFound: if no datasource is found corresponding to the id
-        """
-        for datasource_class in DatasourceDAO.sources.values():
-            try:
-                return (
-                    session.query(datasource_class)
-                    .filter(datasource_class.id == datasource_id)
-                    .one()
-                )
-            except NoResultFound:
-                # proceed to next datasource type
-                pass
-        raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id))
-
+    @classmethod
     def get_datasource_by_name(  # pylint: disable=too-many-arguments
-        self,
+        cls,
         session: Session,
-        datasource_type: str,
+        datasource_type: DatasourceType,
         datasource_name: str,
         schema: str,
         database_name: str,
     ) -> Optional[Datasource]:
         datasource_class = DatasourceDAO.sources[datasource_type]
-        return datasource_class.get_datasource_by_name(
-            session, datasource_name, schema, database_name
-        )
+        if isinstance(datasource_class, SqlaTable):
+            return datasource_class.get_datasource_by_name(
+                session, datasource_name, schema, database_name
+            )
+        return None
 
+    @classmethod
     def query_datasources_by_permissions(  # pylint: disable=invalid-name
-        self,
+        cls,
         session: Session,
         database: Database,
         permissions: Set[str],
         schema_perms: Set[str],
     ) -> List[Datasource]:
         # TODO(bogdan): add unit test
-        datasource_class = DatasourceDAO.sources[database.type]
-        return (
-            session.query(datasource_class)
-            .filter_by(database_id=database.id)
-            .filter(
-                or_(
-                    datasource_class.perm.in_(permissions),
-                    datasource_class.schema_perm.in_(schema_perms),
-                )
-            )
-            .all()
-        )
+        datasource_class = DatasourceDAO.sources[DatasourceType[database.type]]
+        if isinstance(datasource_class, SqlaTable):
+            return (
+                session.query(datasource_class)
+                .filter_by(database_id=database.id)
+                .filter(
+                    or_(
+                        datasource_class.perm.in_(permissions),
+                        datasource_class.schema_perm.in_(schema_perms),
+                    )
+                )
+                .all()
+            )
+        return []
 
+    @classmethod
     def get_eager_datasource(
-        self, session: Session, datasource_type: str, datasource_id: int
-    ) -> Datasource:
+        cls, session: Session, datasource_type: str, datasource_id: int
+    ) -> Optional[Datasource]:
         """Returns datasource with columns and metrics."""
-        datasource_class = DatasourceDAO.sources[datasource_type]
-        return (
-            session.query(datasource_class)
-            .options(
-                subqueryload(datasource_class.columns),
-                subqueryload(datasource_class.metrics),
-            )
-            .filter_by(id=datasource_id)
-            .one()
-        )
+        datasource_class = DatasourceDAO.sources[DatasourceType[datasource_type]]
+        if isinstance(datasource_class, SqlaTable):
+            return (
+                session.query(datasource_class)
+                .options(
+                    subqueryload(datasource_class.columns),
+                    subqueryload(datasource_class.metrics),
+                )
+                .filter_by(id=datasource_id)
+                .one()
+            )
+        return None
 
+    @classmethod
     def query_datasources_by_name(
-        self,
+        cls,
         session: Session,
         database: Database,
         datasource_name: str,
         schema: Optional[str] = None,
     ) -> List[Datasource]:
-        datasource_class = DatasourceDAO.sources[database.type]
-        return datasource_class.query_datasources_by_name(
-            session, database, datasource_name, schema=schema
-        )
+        datasource_class = DatasourceDAO.sources[DatasourceType[database.type]]
+        if isinstance(datasource_class, SqlaTable):
+            return datasource_class.query_datasources_by_name(
+                session, database, datasource_name, schema=schema
+            )
+        return []
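For reference, a minimal usage sketch of the updated DAO, mirroring the calls made in the unit tests below; it assumes a Superset app context and a SQLAlchemy session bound to a metadata database that already holds the records being looked up, none of which is provided by this commit.

from sqlalchemy.orm import Session

from superset.dao.datasource.dao import DatasourceDAO
from superset.utils.core import DatasourceType


def lookup_examples(session: Session) -> None:
    # Lookups are now keyed by DatasourceType members instead of raw strings.
    sqla_table = DatasourceDAO.get_datasource(
        datasource_type=DatasourceType.SQLATABLE, datasource_id=1, session=session
    )
    saved_query = DatasourceDAO.get_datasource(
        datasource_type=DatasourceType.SAVEDQUERY, datasource_id=1, session=session
    )
    # Every record of every model registered in DatasourceDAO.sources.
    all_datasources = DatasourceDAO.get_all_datasources(session=session)
    print(sqla_table, saved_query, len(all_datasources))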
75 changes: 70 additions & 5 deletions tests/unit_tests/dao/datasource_test.py
@@ -18,12 +18,16 @@
 import pytest
 from sqlalchemy.orm.session import Session
 
+from superset.utils.core import DatasourceType
+
 
 def create_test_data(session: Session) -> None:
+    from superset.columns.models import Column
     from superset.connectors.sqla.models import SqlaTable, TableColumn
     from superset.datasets.models import Dataset
     from superset.models.core import Database
     from superset.models.sql_lab import Query, SavedQuery
+    from superset.tables.models import Table
 
     engine = session.get_bind()
     SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
@@ -58,6 +62,32 @@ def create_test_data(session: Session) -> None:
 
     saved_query = SavedQuery(database=db, sql="select * from foo")
 
+    table = Table(
+        name="my_table",
+        schema="my_schema",
+        catalog="my_catalog",
+        database=db,
+        columns=[],
+    )
+
+    dataset = Dataset(
+        database=table.database,
+        name="positions",
+        expression="""
+SELECT array_agg(array[longitude,latitude]) AS position
+FROM my_catalog.my_schema.my_table
+""",
+        tables=[table],
+        columns=[
+            Column(
+                name="position",
+                expression="array_agg(array[longitude,latitude])",
+            ),
+        ],
+    )
+
+    session.add(dataset)
+    session.add(table)
     session.add(saved_query)
     session.add(query_obj)
     session.add(db)
@@ -72,7 +102,7 @@ def test_get_datasource_sqlatable(app_context: None, session: Session) -> None:
     create_test_data(session)
 
     result = DatasourceDAO.get_datasource(
-        datasource_type="table", datasource_id=1, session=session
+        datasource_type=DatasourceType.SQLATABLE, datasource_id=1, session=session
     )
 
     assert 1 == result.id
@@ -87,7 +117,7 @@ def test_get_datasource_query(app_context: None, session: Session) -> None:
     create_test_data(session)
 
     result = DatasourceDAO.get_datasource(
-        datasource_type="query", datasource_id=1, session=session
+        datasource_type=DatasourceType.QUERY, datasource_id=1, session=session
    )
 
     assert result.id == 1
@@ -101,16 +131,51 @@ def test_get_datasource_saved_query(app_context: None, session: Session) -> None
     create_test_data(session)
 
     result = DatasourceDAO.get_datasource(
-        datasource_type="saved_query", datasource_id=1, session=session
+        datasource_type=DatasourceType.SAVEDQUERY, datasource_id=1, session=session
     )
 
     assert result.id == 1
     assert isinstance(result, SavedQuery)
 
 
 def test_get_datasource_sl_table(app_context: None, session: Session) -> None:
-    pass
+    from superset.dao.datasource.dao import DatasourceDAO
+    from superset.tables.models import Table
+
+    create_test_data(session)
+
+    # todo(hugh): This will break once we remove the dual write
+    # update the datsource_id=1 and this will pass again
+    result = DatasourceDAO.get_datasource(
+        datasource_type=DatasourceType.TABLE, datasource_id=2, session=session
+    )
+
+    assert result.id == 2
+    assert isinstance(result, Table)
 
 
 def test_get_datasource_sl_dataset(app_context: None, session: Session) -> None:
-    pass
+    from superset.dao.datasource.dao import DatasourceDAO
+    from superset.datasets.models import Dataset
+
+    create_test_data(session)
+
+    # todo(hugh): This will break once we remove the dual write
+    # update the datsource_id=1 and this will pass again
+    result = DatasourceDAO.get_datasource(
+        datasource_type=DatasourceType.DATASET, datasource_id=2, session=session
+    )
+
+    assert result.id == 2
+    assert isinstance(result, Dataset)
+
+
+def test_get_all_datasources(app_context: None, session: Session) -> None:
+    from superset.dao.datasource.dao import DatasourceDAO
+
+    create_test_data(session)
+
+    # todo(hugh): This will break once we remove the dual write
+    # update the assert len(result) == 5 and this will pass again
+    result = DatasourceDAO.get_all_datasources(session=session)
+    assert len(result) == 7
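A note on the expected counts (an inference from the todo comments above, not something the diff states): create_test_data now adds five objects explicitly, one each of SqlaTable, Query, SavedQuery, Table, and Dataset. The dual write triggered by the SqlaTable insert appears to add one shadow Table and one shadow Dataset, which is why get_all_datasources is expected to return 7 rather than 5, and why the sl_table and sl_dataset lookups use datasource_id=2 instead of 1.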
