This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 6c58fac

Merge pull request #781 from datafold/normalize_schema_info_databricks_redshift
Normalize schema info databricks redshift
2 parents 7329c43 + 579b82c

File tree: 3 files changed, +57 −28 lines


data_diff/databases/databricks.py

+40 −17

@@ -53,6 +53,7 @@ class Dialect(BaseDialect):
         "TIMESTAMP_NTZ": Timestamp,
         # Text
         "STRING": Text,
+        "VARCHAR": Text,
         # Boolean
         "BOOLEAN": Boolean,
     }
@@ -138,25 +139,47 @@ def create_connection(self):
            raise ConnectionError(*e.args) from e

    def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
-        # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
-        # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
-        # So, to obtain information about schema, we should use another approach.
-
        conn = self.create_connection()
+        table_schema = {}

-        catalog, schema, table = self._normalize_table_path(path)
-        with conn.cursor() as cursor:
-            cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table)
-            try:
-                rows = cursor.fetchall()
-            finally:
-                conn.close()
-            if not rows:
-                raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
-
-            d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
-            assert len(d) == len(rows)
-            return d
+        try:
+            table_schema = super().query_table_schema(path)
+        except:
+            logging.warning("Failed to get schema from information_schema, falling back to legacy approach.")
+
+        if not table_schema:
+            # This legacy approach can cause bugs. e.g. VARCHAR(255) -> VARCHAR(255)
+            # and not the expected VARCHAR
+
+            # I don't think we'll fall back to this approach, but if so, see above
+            catalog, schema, table = self._normalize_table_path(path)
+            with conn.cursor() as cursor:
+                cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table)
+                try:
+                    rows = cursor.fetchall()
+                finally:
+                    conn.close()
+                if not rows:
+                    raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
+
+                table_schema = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
+                assert len(table_schema) == len(rows)
+                return table_schema
+        else:
+            return table_schema
+
+    def select_table_schema(self, path: DbPath) -> str:
+        """Provide SQL for selecting the table schema as (name, type, date_prec, num_prec)"""
+        database, schema, name = self._normalize_table_path(path)
+        info_schema_path = ["information_schema", "columns"]
+        if database:
+            info_schema_path.insert(0, database)
+
+        return (
+            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
+            f"FROM {'.'.join(info_schema_path)} "
+            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
+        )

    def _process_table_schema(
        self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
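
For context, here is a standalone sketch (not part of the commit) of the SQL that the new select_table_schema() builds; the catalog, schema, and table names are hypothetical:

# Reproduces the query-building logic of select_table_schema() above,
# outside the class, so it can be run directly. All names are made up.
def build_schema_query(database: str, schema: str, name: str) -> str:
    info_schema_path = ["information_schema", "columns"]
    if database:
        info_schema_path.insert(0, database)
    return (
        "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
        f"FROM {'.'.join(info_schema_path)} "
        f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
    )

print(build_schema_query("my_catalog", "my_schema", "my_table"))
# SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale
# FROM my_catalog.information_schema.columns
# WHERE table_name = 'my_table' AND table_schema = 'my_schema'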

data_diff/databases/redshift.py

+16 −10

@@ -121,15 +121,15 @@ def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]:
        if not rows:
            raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

-        d = {r[0]: r for r in rows}
-        assert len(d) == len(rows)
-        return d
+        schema_dict = self._normalize_schema_info(rows)
+
+        return schema_dict

    def select_view_columns(self, path: DbPath) -> str:
        _, schema, table = self._normalize_table_path(path)

        return """select * from pg_get_cols('{}.{}')
-            cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int)
+            cols(col_name name, col_type varchar)
            """.format(schema, table)
@@ -138,10 +138,17 @@ def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
        if not rows:
            raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns")

-        output = {}
+        schema_dict = self._normalize_schema_info(rows)
+
+        return schema_dict
+
+    # when using a non-information_schema source, strip (N) from type(N) etc. to match
+    # typical information_schema output
+    def _normalize_schema_info(self, rows) -> Dict[str, tuple]:
+        schema_dict = {}
        for r in rows:
-            col_name = r[2]
-            type_info = r[3].split("(")
+            col_name = r[0]
+            type_info = r[1].split("(")
            base_type = type_info[0]
            precision = None
            scale = None
@@ -153,9 +160,8 @@ def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
                scale = int(scale)

            out = [col_name, base_type, None, precision, scale]
-            output[col_name] = tuple(out)
-
-            return output
+            schema_dict[col_name] = tuple(out)
+        return schema_dict

    def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
        try:
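
To see what the new _normalize_schema_info() produces, here is a runnable sketch; the split-on-"(" logic mirrors the diff, but the precision/scale parsing in the middle is elided by the diff, so the version below is an assumption:

from typing import Dict

def normalize_schema_info(rows) -> Dict[str, tuple]:
    # Strip "(N[,M])" from raw type strings so pg_get_cols / external-table
    # output matches what information_schema reports. The precision/scale
    # parsing here is assumed; the diff does not show those lines.
    schema_dict = {}
    for r in rows:
        col_name = r[0]
        type_info = r[1].split("(")
        base_type = type_info[0]
        precision = None
        scale = None
        if len(type_info) > 1:  # e.g. "numeric(10,2)" -> precision 10, scale 2
            parts = type_info[1].rstrip(")").split(",")
            precision = int(parts[0])
            if len(parts) > 1:
                scale = int(parts[1])
        schema_dict[col_name] = (col_name, base_type, None, precision, scale)
    return schema_dict

print(normalize_schema_info([("id", "integer"), ("price", "numeric(10,2)")]))
# {'id': ('id', 'integer', None, None, None), 'price': ('price', 'numeric', None, 10, 2)}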

data_diff/schema.py

+1 −1

@@ -9,7 +9,7 @@


 def create_schema(db_name: str, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
-    logger.debug(f"[{db_name}] Schema = {schema}")
+    logger.info(f"[{db_name}] Schema = {schema}")

     if case_sensitive:
         return CaseSensitiveDict(schema)
