Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use REPLACE INTO instead of INSERT INTO...UPDATE in covid_hosp acquisition #1356

Merged
merged 6 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/acquisition/covid_hosp/common/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,16 @@ def nan_safe_dtype(dtype, value):
for csv_name in self.key_columns:
dataframe.loc[:, csv_name] = dataframe[csv_name].map(self.columns_and_types[csv_name].dtype)

num_columns = 2 + len(dataframe_columns_and_types) + len(self.additional_fields)
value_placeholders = ', '.join(['%s'] * num_columns)
col_names = [f'`{i.sql_name}`' for i in dataframe_columns_and_types + self.additional_fields]
columns = ', '.join(col_names)
updates = ', '.join(f'{c}=new_values.{c}' for c in col_names)
# NOTE: list in `updates` presumes `publication_col_name` is part of the unique key and thus not needed in UPDATE
sql = f'INSERT INTO `{self.table_name}` (`id`, `{self.publication_col_name}`, {columns}) ' \
f'VALUES ({value_placeholders}) AS new_values ' \
f'ON DUPLICATE KEY UPDATE {updates}'
value_placeholders = ', '.join(['%s'] * (2 + len(col_names))) # extra 2 for `id` and `self.publication_col_name` cols
columnstring = ', '.join(col_names)
sql = f'REPLACE INTO `{self.table_name}` (`id`, `{self.publication_col_name}`, {columnstring}) VALUES ({value_placeholders})'
id_and_publication_date = (0, publication_date)
num_values = len(dataframe.index)
if logger:
logger.info('updating values', count=len(dataframe.index))
logger.info('updating values', count=num_values)
n = 0
rows_affected = 0
many_values = []
with self.new_cursor() as cursor:
for index, row in dataframe.iterrows():
Expand All @@ -212,6 +209,7 @@ def nan_safe_dtype(dtype, value):
if n % 5_000 == 0:
try:
cursor.executemany(sql, many_values)
rows_affected += cursor.rowcount
many_values = []
except Exception as e:
if logger:
Expand All @@ -220,6 +218,11 @@ def nan_safe_dtype(dtype, value):
# insert final batch
if many_values:
cursor.executemany(sql, many_values)
rows_affected += cursor.rowcount
if logger:
# NOTE: REPLACE INTO marks 2 rows affected for a "replace" (one for a delete and one for a re-insert)
# which allows us to count rows which were updated
logger.info('rows affected', total=rows_affected, updated=rows_affected-num_values)

# deal with non/seldomly updated columns used like a fk table (if this database needs it)
if hasattr(self, 'AGGREGATE_KEY_COLS'):
Expand Down
2 changes: 1 addition & 1 deletion tests/acquisition/covid_hosp/common/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_insert_dataset(self):

actual_sql = mock_cursor.executemany.call_args[0][0]
self.assertIn(
'INSERT INTO `test_table` (`id`, `publication_date`, `sql_str_col`, `sql_int_col`, `sql_float_col`)',
'REPLACE INTO `test_table` (`id`, `publication_date`, `sql_str_col`, `sql_int_col`, `sql_float_col`)',
actual_sql)

expected_values = [
Expand Down