Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Fix find cannot use scopes #893

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/masoniteorm/models/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,12 @@ def find(cls, record_id, query=False):
builder = cls().where(cls.get_primary_key(), record_id)

if query:
return builder.to_sql()
else:
if isinstance(record_id, (list, tuple)):
return builder.get()
return builder

if isinstance(record_id, (list, tuple)):
return builder.get()

return builder.first()
return builder.first()

@classmethod
def find_or_fail(cls, record_id, query=False):
Expand Down
15 changes: 8 additions & 7 deletions src/masoniteorm/query/QueryBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,7 @@ def last(self, column=None, query=False):
def _get_eager_load_result(self, related, collection):
return related.eager_load_from_collection(collection)

def find(self, record_id):
def find(self, record_id, query=False):
"""Finds a row by the primary key ID. Requires a model

Arguments:
Expand All @@ -1801,8 +1801,12 @@ def find(self, record_id):
Returns:
Model|None
"""
self.where(self._model.get_primary_key(), record_id)

return self.where(self._model.get_primary_key(), record_id).first()
if query:
return self

return self.first()

def find_or(self, record_id: int, callback: Callable, args=None):
"""Finds a row by the primary key ID (Requires a model) or raise a ModelNotFound exception.
Expand Down Expand Up @@ -2088,9 +2092,8 @@ def to_sql(self):
Returns:
self
"""
for name, scope in self._global_scopes.get(self._action, {}).items():
scope(self)

self.run_scopes()
grammar = self.get_grammar()
sql = grammar.compile(self._action, qmark=False).to_sql()
return sql
Expand All @@ -2117,11 +2120,9 @@ def to_qmark(self):
Returns:
self
"""
for name, scope in self._global_scopes.get(self._action, {}).items():
scope(self)

self.run_scopes()
grammar = self.get_grammar()

sql = grammar.compile(self._action, qmark=True).to_sql()

self._bindings = grammar._bindings
Expand Down
27 changes: 20 additions & 7 deletions tests/mysql/scopes/test_soft_delete.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
import inspect
import unittest

import pendulum

from tests.integrations.config.database import DATABASES
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MySQLGrammar
from src.masoniteorm.scopes import SoftDeleteScope
from tests.utils import MockConnectionFactory

from src.masoniteorm.models import Model
from src.masoniteorm.scopes import SoftDeletesMixin
from tests.User import User


class UserSoft(Model, SoftDeletesMixin):
__dry__ = True

__table__ = "users"

class UserSoftArchived(Model, SoftDeletesMixin):
__dry__ = True

__deleted_at__ = "archived_at"
__table__ = "users"

Expand Down Expand Up @@ -52,7 +50,7 @@ def test_restore(self):
self.assertEqual(sql, builder.restore().to_sql())

def test_force_delete_with_wheres(self):
sql = "DELETE FROM `user_softs` WHERE `user_softs`.`active` = '1'"
sql = "DELETE FROM `users` WHERE `users`.`active` = '1'"
builder = self.get_builder().set_global_scope(SoftDeleteScope())
self.assertEqual(
sql, UserSoft.where("active", 1).force_delete(query=True).to_sql()
Expand All @@ -69,9 +67,24 @@ def test_only_trashed(self):
self.assertEqual(sql, builder.only_trashed().to_sql())

def test_only_trashed_on_model(self):
sql = "SELECT * FROM `user_softs` WHERE `user_softs`.`deleted_at` IS NOT NULL"
sql = "SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL"
self.assertEqual(sql, UserSoft.only_trashed().to_sql())

def test_can_change_column(self):
sql = "SELECT * FROM `users` WHERE `users`.`archived_at` IS NOT NULL"
self.assertEqual(sql, UserSoftArchived.only_trashed().to_sql())

def test_find_with_global_scope(self):
find_sql = UserSoft.find("1", query=True).to_sql()
raw_sql = """SELECT * FROM `users` WHERE `users`.`id` = '1' AND `users`.`deleted_at` IS NULL"""
self.assertEqual(find_sql, raw_sql)

def test_find_with_trashed_scope(self):
find_sql = UserSoft.with_trashed().find("1", query=True).to_sql()
raw_sql = """SELECT * FROM `users` WHERE `users`.`id` = '1'"""
self.assertEqual(find_sql, raw_sql)

def test_find_with_only_trashed_scope(self):
find_sql = UserSoft.only_trashed().find("1", query=True).to_sql()
raw_sql = """SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL AND `users`.`id` = '1'"""
self.assertEqual(find_sql, raw_sql)
6 changes: 2 additions & 4 deletions tests/sqlite/models/test_sqlite_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,10 @@ def test_update_all_records(self):
self.assertEqual(sql, """UPDATE "users" SET "name" = 'joe'""")

def test_can_find_list(self):
sql = User.find(1, query=True)

sql = User.find(1, query=True).to_sql()
self.assertEqual(sql, """SELECT * FROM "users" WHERE "users"."id" = '1'""")

sql = User.find([1, 2, 3], query=True)

sql = User.find([1, 2, 3], query=True).to_sql()
self.assertEqual(
sql, """SELECT * FROM "users" WHERE "users"."id" IN ('1','2','3')"""
)
Expand Down
Loading