Skip to content

Commit

Permalink
Merge pull request #11567 from marcellamaki/content-import-fix
Browse files Browse the repository at this point in the history
Moves setting current_channel to init
  • Loading branch information
bjester authored Nov 30, 2023
2 parents 90df456 + e75855d commit 033b522
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 31 deletions.
100 changes: 73 additions & 27 deletions kolibri/core/content/test/test_channel_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ class BaseChannelImportClassConstructorTestCase(TestCase):
"""

def test_channel_id(self, apps_mock, tree_id_mock, BridgeMock):
channel_import = ChannelImport("test", "")
self.assertEqual(channel_import.channel_id, "test")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)

@patch("kolibri.core.content.utils.channel_import.get_content_database_file_path")
def test_two_bridges(self, db_path_mock, apps_mock, tree_id_mock, BridgeMock):
db_path_mock.return_value = "test"
ChannelImport("test", "source")
idValue = uuid.uuid4().hex
db_path_mock.return_value = idValue
ChannelImport(idValue, "source")
BridgeMock.assert_has_calls(
[
call(sqlite_file_path="source"),
Expand All @@ -78,7 +80,8 @@ def test_two_bridges(self, db_path_mock, apps_mock, tree_id_mock, BridgeMock):

@patch("kolibri.core.content.utils.channel_import.get_content_database_file_path")
def test_get_config(self, db_path_mock, apps_mock, tree_id_mock, BridgeMock):
ChannelImport("test", "")
idValue = uuid.uuid4().hex
ChannelImport(idValue, "")
apps_mock.assert_has_calls(
[
call.get_app_config("content"),
Expand All @@ -87,7 +90,8 @@ def test_get_config(self, db_path_mock, apps_mock, tree_id_mock, BridgeMock):
)

def test_tree_id(self, apps_mock, tree_id_mock, BridgeMock):
ChannelImport("test", "")
idValue = uuid.uuid4().hex
ChannelImport(idValue, "")
tree_id_mock.assert_called_once_with()


Expand All @@ -103,47 +107,65 @@ class BaseChannelImportClassMethodUniqueTreeIdTestCase(TestCase):

def test_empty(self, apps_mock, tree_ids_mock, BridgeMock):
tree_ids_mock.return_value = []
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
self.assertEqual(channel_import.find_unique_tree_id(), 1)

def test_one_one(self, apps_mock, tree_ids_mock, BridgeMock):
tree_ids_mock.return_value = [1]
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
self.assertEqual(channel_import.find_unique_tree_id(), 2)

def test_one_two(self, apps_mock, tree_ids_mock, BridgeMock):
tree_ids_mock.return_value = [2]
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
self.assertEqual(channel_import.find_unique_tree_id(), 1)

def test_two_one_two(self, apps_mock, tree_ids_mock, BridgeMock):
tree_ids_mock.return_value = [1, 2]
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
self.assertEqual(channel_import.find_unique_tree_id(), 3)

def test_two_one_three(self, apps_mock, tree_ids_mock, BridgeMock):
tree_ids_mock.return_value = [1, 3]
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
self.assertEqual(channel_import.find_unique_tree_id(), 2)

def test_three_one_two_three(self, apps_mock, tree_ids_mock, BridgeMock):
tree_ids_mock.return_value = [1, 2, 3]
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
self.assertEqual(channel_import.find_unique_tree_id(), 4)

def test_three_one_two_four(self, apps_mock, tree_ids_mock, BridgeMock):
tree_ids_mock.return_value = [1, 2, 4]
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
self.assertEqual(channel_import.find_unique_tree_id(), 3)

def test_three_one_three_four(self, apps_mock, tree_ids_mock, BridgeMock):
tree_ids_mock.return_value = [1, 3, 4]
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
self.assertEqual(channel_import.find_unique_tree_id(), 2)

def test_three_one_three_five(self, apps_mock, tree_ids_mock, BridgeMock):
tree_ids_mock.return_value = [1, 3, 5]
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
self.assertEqual(channel_import.find_unique_tree_id(), 2)


Expand All @@ -156,22 +178,28 @@ class BaseChannelImportClassGenRowMapperTestCase(TestCase):
"""

def test_base_mapper(self, apps_mock, tree_id_mock, BridgeMock):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
mapper = channel_import.generate_row_mapper()
record = MagicMock()
record.test_attr = "test_val"
self.assertEqual(mapper(record, "test_attr"), "test_val")

def test_column_name_mapping(self, apps_mock, tree_id_mock, BridgeMock):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
mappings = {"test_attr": "test_attr_mapped"}
mapper = channel_import.generate_row_mapper(mappings=mappings)
record = MagicMock()
record.test_attr_mapped = "test_val"
self.assertEqual(mapper(record, "test_attr"), "test_val")

def test_method_mapping(self, apps_mock, tree_id_mock, BridgeMock):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
mappings = {"test_attr": "test_map_method"}
mapper = channel_import.generate_row_mapper(mappings=mappings)
record = {}
Expand All @@ -181,7 +209,9 @@ def test_method_mapping(self, apps_mock, tree_id_mock, BridgeMock):
self.assertEqual(mapper(record, "test_attr"), "test_val")

def test_no_column_mapping(self, apps_mock, tree_id_mock, BridgeMock):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
mappings = {"test_attr": "test_attr_mapped"}
mapper = channel_import.generate_row_mapper(mappings=mappings)
record = Mock(spec=["test_attr"])
Expand All @@ -198,20 +228,26 @@ class BaseChannelImportClassGenTableMapperTestCase(TestCase):
"""

def test_base_mapper(self, apps_mock, tree_id_mock, BridgeMock):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
mapper = channel_import.generate_table_mapper()
self.assertEqual(mapper, channel_import.base_table_mapper)

def test_method_mapping(self, apps_mock, tree_id_mock, BridgeMock):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
table_map = "test_map_method"
test_map_method = Mock()
channel_import.test_map_method = test_map_method
mapper = channel_import.generate_table_mapper(table_map=table_map)
self.assertEqual(mapper, test_map_method)

def test_no_column_mapping(self, apps_mock, tree_id_mock, BridgeMock):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
table_map = "test_map_method"
with self.assertRaises(AttributeError):
channel_import.generate_table_mapper(table_map=table_map)
Expand All @@ -228,7 +264,9 @@ class BaseChannelImportClassTableImportTestCase(TestCase):
def test_no_merge_records_bulk_insert_no_flush(
self, apps_mock, tree_id_mock, BridgeMock
):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
record_mock = MagicMock(spec=["__table__"])
record_mock.__table__.columns.items.return_value = [("test_attr", MagicMock())]
channel_import.destination.get_class.return_value = record_mock
Expand All @@ -240,7 +278,9 @@ def test_no_merge_records_bulk_insert_no_flush(
def test_no_merge_records_bulk_insert_flush(
self, apps_mock, tree_id_mock, BridgeMock
):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
record_mock = MagicMock(spec=["__table__"])
record_mock.__table__.columns.items.return_value = [("test_attr", MagicMock())]
channel_import.destination.get_class.return_value = record_mock
Expand All @@ -259,7 +299,9 @@ class BaseChannelImportClassOtherMethodsTestCase(TestCase):
"""

def test_import_channel_methods_called(self, apps_mock, tree_id_mock, BridgeMock):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
model_mock = Mock(spec=["__name__"])
channel_import.content_models = [model_mock]
mapping_mock = Mock()
Expand All @@ -283,15 +325,19 @@ def test_import_channel_methods_called(self, apps_mock, tree_id_mock, BridgeMock
channel_import.execute_post_operations.assert_called_once()

def test_end(self, apps_mock, tree_id_mock, BridgeMock):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
channel_import.end()
channel_import.destination.end.assert_has_calls([call(), call()])

@patch("kolibri.core.content.utils.channel_import.select")
def test_destination_tree_ids(
self, select_mock, apps_mock, tree_id_mock, BridgeMock
):
channel_import = ChannelImport("test", "")
idValue = uuid.uuid4().hex
channel_import = ChannelImport(idValue, "")
self.assertEqual(channel_import.channel_id, idValue)
class_mock = Mock()
channel_import.destination.get_class.return_value = class_mock
channel_import.get_all_destination_tree_ids()
Expand Down
10 changes: 6 additions & 4 deletions kolibri/core/content/utils/channel_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ def __init__(
):
self.channel_id = channel_id
self.channel_version = channel_version
try:
self.current_channel = ChannelMetadata.objects.get(id=self.channel_id)
except ChannelMetadata.DoesNotExist:
self.current_channel = None

self.cancel_check = cancel_check

Expand Down Expand Up @@ -951,11 +955,9 @@ def import_channel_data(self):
return import_ran

def run_and_annotate(self):
try:
self.current_channel = ChannelMetadata.objects.get(id=self.channel_id)
if self.current_channel:
old_order = self.current_channel.order
except ChannelMetadata.DoesNotExist:
self.current_channel = None
else:
old_order = None

import_ran = self.import_channel_data()
Expand Down

0 comments on commit 033b522

Please # to comment.