diff --git a/src/masoniteorm/models/Model.py b/src/masoniteorm/models/Model.py index 052f08ee..5437935b 100644 --- a/src/masoniteorm/models/Model.py +++ b/src/masoniteorm/models/Model.py @@ -426,12 +426,12 @@ def find(cls, record_id, query=False): builder = cls().where(cls.get_primary_key(), record_id) if query: - return builder.to_sql() - else: - if isinstance(record_id, (list, tuple)): - return builder.get() + return builder + + if isinstance(record_id, (list, tuple)): + return builder.get() - return builder.first() + return builder.first() @classmethod def find_or_fail(cls, record_id, query=False): diff --git a/src/masoniteorm/query/QueryBuilder.py b/src/masoniteorm/query/QueryBuilder.py index 4e9a663b..e6b48129 100644 --- a/src/masoniteorm/query/QueryBuilder.py +++ b/src/masoniteorm/query/QueryBuilder.py @@ -1792,7 +1792,7 @@ def last(self, column=None, query=False): def _get_eager_load_result(self, related, collection): return related.eager_load_from_collection(collection) - def find(self, record_id): + def find(self, record_id, query=False): """Finds a row by the primary key ID. Requires a model Arguments: @@ -1801,8 +1801,12 @@ def find(self, record_id): Returns: Model|None """ + self.where(self._model.get_primary_key(), record_id) - return self.where(self._model.get_primary_key(), record_id).first() + if query: + return self + + return self.first() def find_or(self, record_id: int, callback: Callable, args=None): """Finds a row by the primary key ID (Requires a model) or raise a ModelNotFound exception. @@ -2088,9 +2092,8 @@ def to_sql(self): Returns: self """ - for name, scope in self._global_scopes.get(self._action, {}).items(): - scope(self) + self.run_scopes() grammar = self.get_grammar() sql = grammar.compile(self._action, qmark=False).to_sql() return sql @@ -2117,11 +2120,9 @@ def to_qmark(self): Returns: self """ - for name, scope in self._global_scopes.get(self._action, {}).items(): - scope(self) + self.run_scopes() grammar = self.get_grammar() - sql = grammar.compile(self._action, qmark=True).to_sql() self._bindings = grammar._bindings diff --git a/tests/mysql/scopes/test_soft_delete.py b/tests/mysql/scopes/test_soft_delete.py index b48f829e..1a0c6c43 100644 --- a/tests/mysql/scopes/test_soft_delete.py +++ b/tests/mysql/scopes/test_soft_delete.py @@ -1,8 +1,8 @@ -import inspect import unittest +import pendulum + from tests.integrations.config.database import DATABASES -from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.scopes import SoftDeleteScope @@ -10,16 +10,14 @@ from src.masoniteorm.models import Model from src.masoniteorm.scopes import SoftDeletesMixin -from tests.User import User class UserSoft(Model, SoftDeletesMixin): __dry__ = True - + __table__ = "users" class UserSoftArchived(Model, SoftDeletesMixin): __dry__ = True - __deleted_at__ = "archived_at" __table__ = "users" @@ -52,7 +50,7 @@ def test_restore(self): self.assertEqual(sql, builder.restore().to_sql()) def test_force_delete_with_wheres(self): - sql = "DELETE FROM `user_softs` WHERE `user_softs`.`active` = '1'" + sql = "DELETE FROM `users` WHERE `users`.`active` = '1'" builder = self.get_builder().set_global_scope(SoftDeleteScope()) self.assertEqual( sql, UserSoft.where("active", 1).force_delete(query=True).to_sql() @@ -69,9 +67,24 @@ def test_only_trashed(self): self.assertEqual(sql, builder.only_trashed().to_sql()) def test_only_trashed_on_model(self): - sql = "SELECT * FROM `user_softs` WHERE `user_softs`.`deleted_at` IS NOT NULL" + sql = "SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL" self.assertEqual(sql, UserSoft.only_trashed().to_sql()) def test_can_change_column(self): sql = "SELECT * FROM `users` WHERE `users`.`archived_at` IS NOT NULL" self.assertEqual(sql, UserSoftArchived.only_trashed().to_sql()) + + def test_find_with_global_scope(self): + find_sql = UserSoft.find("1", query=True).to_sql() + raw_sql = """SELECT * FROM `users` WHERE `users`.`id` = '1' AND `users`.`deleted_at` IS NULL""" + self.assertEqual(find_sql, raw_sql) + + def test_find_with_trashed_scope(self): + find_sql = UserSoft.with_trashed().find("1", query=True).to_sql() + raw_sql = """SELECT * FROM `users` WHERE `users`.`id` = '1'""" + self.assertEqual(find_sql, raw_sql) + + def test_find_with_only_trashed_scope(self): + find_sql = UserSoft.only_trashed().find("1", query=True).to_sql() + raw_sql = """SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL AND `users`.`id` = '1'""" + self.assertEqual(find_sql, raw_sql) diff --git a/tests/sqlite/models/test_sqlite_model.py b/tests/sqlite/models/test_sqlite_model.py index c833c113..8ace14ec 100644 --- a/tests/sqlite/models/test_sqlite_model.py +++ b/tests/sqlite/models/test_sqlite_model.py @@ -72,12 +72,10 @@ def test_update_all_records(self): self.assertEqual(sql, """UPDATE "users" SET "name" = 'joe'""") def test_can_find_list(self): - sql = User.find(1, query=True) - + sql = User.find(1, query=True).to_sql() self.assertEqual(sql, """SELECT * FROM "users" WHERE "users"."id" = '1'""") - sql = User.find([1, 2, 3], query=True) - + sql = User.find([1, 2, 3], query=True).to_sql() self.assertEqual( sql, """SELECT * FROM "users" WHERE "users"."id" IN ('1','2','3')""" )