Skip to content

Commit 2991b03

Browse files
committed
feat: allow passing upsert_fields with None to SQLizer.upsert_on_duplicate
1 parent f003567 commit 2991b03

File tree

4 files changed

+35
-17
lines changed

4 files changed

+35
-17
lines changed

examples/service/routers/account.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ async def update_last_login_view(
153153
@router.post("/bulk_upsert")
154154
async def bulk_upsert_view():
155155
dicts = []
156-
for (idx, locale) in enumerate([LocaleEnum.ja_JP, LocaleEnum.en_IN, LocaleEnum.zh_CN], 7):
156+
for (idx, locale) in enumerate([LocaleEnum.ja_JP, LocaleEnum.en_IN, LocaleEnum.zh_CN], 10):
157157
faker = Faker(locale)
158158
dicts.append({
159159
"id": idx,
@@ -165,8 +165,7 @@ async def bulk_upsert_view():
165165
row_cnt = await AccountMgr.upsert_on_duplicate(
166166
dicts,
167167
insert_fields=["id", "gender", "name", "locale", "extend"],
168-
upsert_fields=["name", "locale"],
169-
using_values=False,
168+
# upsert_fields=["name", "locale"],
170169
)
171170
return {"row_cnt": row_cnt}
172171

fastapi_esql/orm/base_manager.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ async def upsert_on_duplicate(
149149
cls,
150150
dicts: List[Dict[str, Any]],
151151
insert_fields: List[str],
152-
upsert_fields: List[str],
152+
upsert_fields: Optional[List[str]] = None,
153153
using_values: bool = False,
154154
):
155155
sql = SQLizer.upsert_on_duplicate(

fastapi_esql/utils/sqlizer.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -195,37 +195,38 @@ def upsert_on_duplicate(
195195
table: str,
196196
dicts: List[Dict[str, Any]],
197197
insert_fields: List[str],
198-
upsert_fields: List[str],
198+
upsert_fields: Optional[List[str]] = None,
199199
using_values: bool = False,
200200
) -> Optional[str]:
201-
if not all([table, dicts, insert_fields, upsert_fields]):
202-
raise WrongParamsError("Parameters `table`, `dicts`, `insert_fields`, `upsert_fields` are required")
201+
if not all([table, dicts, insert_fields]):
202+
raise WrongParamsError("Parameters `table`, `dicts`, `insert_fields` are required")
203203

204204
values = [
205205
f" ({', '.join(cls.sqlize_value(d.get(f)) for f in insert_fields)})"
206206
for d in dicts
207207
]
208208
# NOTE Beginning with MySQL 8.0.19, it is possible to use an alias for the row
209209
# https://dev.mysql.com/doc/refman/8.0/en/insert-on-duplicate.html
210-
if using_values:
211-
upserts = [f"{field}=VALUES({field})" for field in upsert_fields]
212-
on_duplicate = f"ON DUPLICATE KEY UPDATE {', '.join(upserts)}"
213-
else:
214-
new_table = f"`new_{table}`"
215-
upserts = [f"{field}={new_table}.{field}" for field in upsert_fields]
216-
on_duplicate = f"AS {new_table} ON DUPLICATE KEY UPDATE {', '.join(upserts)}"
210+
on_duplicate = ""
211+
if upsert_fields:
212+
if using_values:
213+
upserts = [f"{field}=VALUES({field})" for field in upsert_fields]
214+
on_duplicate = f"ON DUPLICATE KEY UPDATE {', '.join(upserts)}"
215+
else:
216+
new_table = f"`new_{table}`"
217+
upserts = [f"{field}={new_table}.{field}" for field in upsert_fields]
218+
on_duplicate = f"AS {new_table} ON DUPLICATE KEY UPDATE {', '.join(upserts)}"
217219

218220
sql = """
219221
INSERT INTO {}
220222
({})
221223
VALUES
222-
{}
223-
{}
224+
{}{}
224225
""".format(
225226
table,
226227
", ".join(insert_fields),
227228
",\n".join(values),
228-
on_duplicate,
229+
f"\n {on_duplicate}" if on_duplicate else "",
229230
)
230231
logger.debug(sql)
231232
return sql

tests/test_sqlizer.py

+18
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,24 @@ def test_upsert_on_duplicate(self):
251251
AS `new_account` ON DUPLICATE KEY UPDATE name=`new_account`.name, locale=`new_account`.locale
252252
"""
253253

254+
only_insert_sql = SQLizer.upsert_on_duplicate(
255+
self.table,
256+
[
257+
{'id': 7, 'gender': 1, 'name': '斉藤 修平', 'locale': 'ja_JP', 'extend': {}},
258+
{'id': 8, 'gender': 1, 'name': 'Ojas Salvi', 'locale': 'en_IN', 'extend': {}},
259+
{'id': 9, 'gender': 1, 'name': '羊淑兰', 'locale': 'zh_CN', 'extend': {}}
260+
],
261+
insert_fields=["id", "gender", "name", "locale", "extend"],
262+
)
263+
assert only_insert_sql == """
264+
INSERT INTO account
265+
(id, gender, name, locale, extend)
266+
VALUES
267+
(7, 1, '斉藤 修平', 'ja_JP', '{}'),
268+
(8, 1, 'Ojas Salvi', 'en_IN', '{}'),
269+
(9, 1, '羊淑兰', 'zh_CN', '{}')
270+
"""
271+
254272
def test_insert_into_select(self):
255273
with self.assertRaises(WrongParamsError):
256274
SQLizer.insert_into_select(

0 commit comments

Comments
 (0)