diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 276366c6b054f..e9d7e7f32f9d9 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -196,7 +196,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): add_model_schema = DatasetPostSchema() edit_model_schema = DatasetPutSchema() duplicate_model_schema = DatasetDuplicateSchema() - add_columns = ["database", "schema", "table_name", "owners"] + add_columns = ["database", "schema", "table_name", "sql", "owners"] edit_columns = [ "table_name", "sql", diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py index 1fa2e0ccf7660..809eec7a1159a 100644 --- a/superset/datasets/commands/create.py +++ b/superset/datasets/commands/create.py @@ -59,6 +59,7 @@ def validate(self) -> None: database_id = self._properties["database"] table_name = self._properties["table_name"] schema = self._properties.get("schema", None) + sql = self._properties.get("sql", None) owner_ids: Optional[List[int]] = self._properties.get("owners") # Validate uniqueness @@ -71,9 +72,12 @@ def validate(self) -> None: exceptions.append(DatabaseNotFoundValidationError()) self._properties["database"] = database - # Validate table exists on dataset - if database and not DatasetDAO.validate_table_exists( - database, table_name, schema + # Validate table exists on dataset if sql is not provided + # This should be validated when the dataset is physical + if ( + database + and not sql + and not DatasetDAO.validate_table_exists(database, table_name, schema) ): exceptions.append(TableNotFoundValidationError(table_name)) diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index 9d2b474894b02..223324da3aa9b 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -80,6 +80,7 @@ class DatasetPostSchema(Schema): database = fields.Integer(required=True) schema = fields.String(validate=Length(0, 250)) table_name = fields.String(required=True, allow_none=False, validate=Length(1, 250)) + sql = fields.String(allow_none=True) owners = fields.List(fields.Integer()) is_managed_externally = fields.Boolean(allow_none=True, default=False) external_url = fields.String(allow_none=True) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 2045a0fdcf384..ef003d05dc600 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -608,6 +608,61 @@ def test_create_dataset_validate_uniqueness(self): "message": {"table_name": ["Dataset energy_usage already exists"]} } + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_create_dataset_with_sql_validate_uniqueness(self): + """ + Dataset API: Test create dataset with sql + """ + if backend() == "sqlite": + return + + schema = get_example_default_schema() + energy_usage_ds = self.get_energy_usage_dataset() + self.login(username="admin") + table_data = { + "database": energy_usage_ds.database_id, + "table_name": energy_usage_ds.table_name, + "sql": "select * from energy_usage", + } + if schema: + table_data["schema"] = schema + rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post") + assert rv.status_code == 422 + data = json.loads(rv.data.decode("utf-8")) + assert data == { + "message": {"table_name": ["Dataset energy_usage already exists"]} + } + + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_create_dataset_with_sql(self): + """ + Dataset API: Test create dataset with sql + """ + if backend() == "sqlite": + return + + schema = get_example_default_schema() + energy_usage_ds = self.get_energy_usage_dataset() + self.login(username="alpha") + admin = self.get_user("admin") + alpha = self.get_user("alpha") + table_data = { + "database": energy_usage_ds.database_id, + "table_name": "energy_usage_virtual", + "sql": "select * from energy_usage", + "owners": [admin.id], + } + if schema: + table_data["schema"] = schema + rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post") + assert rv.status_code == 201 + data = json.loads(rv.data.decode("utf-8")) + model = db.session.query(SqlaTable).get(data.get("id")) + assert admin in model.owners + assert alpha in model.owners + db.session.delete(model) + db.session.commit() + @unittest.skip("test is failing stochastically") def test_create_dataset_same_name_different_schema(self): if backend() == "sqlite":