From 8b38cb8065d2192495a8e4f7a4bdbf31287a649c Mon Sep 17 00:00:00 2001 From: HibiKier <775757368@qq.com> Date: Tue, 1 Oct 2024 00:04:44 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20=E5=B0=86cd=EF=BC=8Cblock=EF=BC=8Cc?= =?UTF-8?q?ount=E9=99=90=E5=88=B6=E5=A4=8D=E5=8E=9F=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../builtin_plugins/hooks/_auth_checker.py | 16 +- zhenxun/builtin_plugins/init/init_config.py | 17 +- zhenxun/builtin_plugins/init/init_plugin.py | 48 +- zhenxun/builtin_plugins/init/manager.py | 417 ++++++++++++++++++ zhenxun/configs/utils/__init__.py | 11 - zhenxun/models/plugin_info.py | 26 ++ zhenxun/models/plugin_limit.py | 2 +- 7 files changed, 491 insertions(+), 46 deletions(-) create mode 100644 zhenxun/builtin_plugins/init/manager.py diff --git a/zhenxun/builtin_plugins/hooks/_auth_checker.py b/zhenxun/builtin_plugins/hooks/_auth_checker.py index df18e6a3d..687fc44d4 100644 --- a/zhenxun/builtin_plugins/hooks/_auth_checker.py +++ b/zhenxun/builtin_plugins/hooks/_auth_checker.py @@ -88,7 +88,7 @@ def unblock( @classmethod async def check( cls, - module: str, + module_path: str, user_id: str, group_id: str | None, channel_id: str | None, @@ -106,17 +106,17 @@ async def check( 异常: IgnoredException: IgnoredException """ - if limit_model := cls.cd_limit.get(module): + if limit_model := cls.cd_limit.get(module_path): await cls.__check(limit_model, user_id, group_id, channel_id, session) - if limit_model := cls.block_limit.get(module): + if limit_model := cls.block_limit.get(module_path): await cls.__check(limit_model, user_id, group_id, channel_id, session) - if limit_model := cls.count_limit.get(module): + if limit_model := cls.count_limit.get(module_path): await cls.__check(limit_model, user_id, group_id, channel_id, session) @classmethod async def __check( cls, - limit_model: Limit, + limit_model: Limit | None, user_id: str, group_id: str | None, channel_id: str | None, @@ -291,12 +291,14 @@ async def auth_limit(self, plugin: PluginInfo, session: EventSession): if not group_id: group_id = channel_id channel_id = None - limit_list: list[PluginLimit] = await plugin.plugin_limit.all() # type: ignore + limit_list: list[PluginLimit] = await plugin.plugin_limit.filter( + status=True + ).all() # type: ignore for limit in limit_list: LimitManage.add_limit(limit) if user_id: await LimitManage.check( - plugin.module, user_id, group_id, channel_id, session + plugin.module_path, user_id, group_id, channel_id, session ) async def auth_plugin( diff --git a/zhenxun/builtin_plugins/init/init_config.py b/zhenxun/builtin_plugins/init/init_config.py index 26534d4d5..a2a7f3324 100644 --- a/zhenxun/builtin_plugins/init/init_config.py +++ b/zhenxun/builtin_plugins/init/init_config.py @@ -1,16 +1,16 @@ from pathlib import Path import nonebot -from nonebot import get_loaded_plugins -from nonebot.drivers import Driver -from nonebot.plugin import Plugin from ruamel.yaml import YAML +from nonebot.plugin import Plugin +from nonebot.drivers import Driver +from nonebot import get_loaded_plugins from ruamel.yaml.comments import CommentedMap +from zhenxun.services.log import logger from zhenxun.configs.config import Config -from zhenxun.configs.path_config import DATA_PATH from zhenxun.configs.utils import RegisterConfig -from zhenxun.services.log import logger +from zhenxun.configs.path_config import DATA_PATH _yaml = YAML(pure=True) _yaml.allow_unicode = True @@ -72,15 +72,14 @@ def _generate_simple_config(): Config.set_config(module, k, _data[module][k]) _tmp_data[module][k] = Config.get_config(module, k) except AttributeError as e: - raise AttributeError(f"{e}\n" + "可能为config.yaml配置文件填写不规范") + raise AttributeError(f"{e}\n可能为config.yaml配置文件填写不规范") from e Config.save() temp_file = DATA_PATH / "temp_config.yaml" # 重新生成简易配置文件 try: with open(temp_file, "w", encoding="utf8") as wf: - # yaml.dump(_tmp_data, wf, Dumper=yaml.RoundTripDumper, allow_unicode=True) _yaml.dump(_tmp_data, wf) - with open(temp_file, "r", encoding="utf8") as rf: + with open(temp_file, encoding="utf8") as rf: _data = _yaml.load(rf) # 添加注释 for module in _data.keys(): @@ -93,7 +92,7 @@ def _generate_simple_config(): with SIMPLE_CONFIG_FILE.open("w", encoding="utf8") as wf: _yaml.dump(_data, wf) except Exception as e: - logger.error(f"生成简易配置注释错误...", e=e) + logger.error("生成简易配置注释错误...", e=e) if temp_file.exists(): temp_file.unlink() diff --git a/zhenxun/builtin_plugins/init/init_plugin.py b/zhenxun/builtin_plugins/init/init_plugin.py index da018eac2..fff731294 100644 --- a/zhenxun/builtin_plugins/init/init_plugin.py +++ b/zhenxun/builtin_plugins/init/init_plugin.py @@ -21,6 +21,8 @@ PluginLimitType, ) +from .manager import manager + _yaml = YAML(pure=True) _yaml.allow_unicode = True _yaml.indent = 2 @@ -148,23 +150,25 @@ async def _(): # ["name", "author", "version", "admin_level", "plugin_type"], # 10, # ) - if limit_list: - limit_create = [] - plugins = [] - if module_path_list := [limit.module_path for limit in limit_list]: - plugins = await PluginInfo.filter(module_path__in=module_path_list).all() - if plugins: - for limit in limit_list: - if lmt := [p for p in plugins if p.module_path == limit.module_path]: - plugin = lmt[0] - limit_type_list = [ - _limit.limit_type for _limit in await plugin.plugin_limit.all() # type: ignore - ] - if limit.limit_type not in limit_type_list: - limit.plugin = plugin - limit_create.append(limit) - if limit_create: - await PluginLimit.bulk_create(limit_create, 10) + # for limit in limit_list: + # limit_create = [] + # plugins = [] + # if module_path_list := [limit.module_path for limit in limit_list]: + # plugins = await PluginInfo.get_plugins(module_path__in=module_path_list) + # if plugins: + # for limit in limit_list: + # if lmt := [p for p in plugins if p.module_path == limit.module_path]: + # plugin = lmt[0] + # """不在数据库中""" + # limit_type_list = [ + # _limit.limit_type + # for _limit in await plugin.plugin_limit.all() # type: ignore + # ] + # if limit.limit_type not in limit_type_list: + # limit.plugin = plugin + # limit_create.append(limit) + # if limit_create: + # await PluginLimit.bulk_create(limit_create, 10) if task_list: module_dict = { t[1]: t[0] for t in await TaskInfo.all().values_list("id", "module") @@ -195,10 +199,18 @@ async def _(): await data_migration() await PluginInfo.filter(module_path__in=load_plugin).update(load_status=True) await PluginInfo.filter(module_path__not_in=load_plugin).update(load_status=False) + manager.init() + if limit_list: + for limit in limit_list: + if not manager.exist(limit.module_path, limit.limit_type): + """不存在,添加""" + manager.add(limit.module_path, limit) + manager.save_file() + await manager.load_to_db() async def data_migration(): - await limit_migration() + # await limit_migration() await plugin_migration() await group_migration() diff --git a/zhenxun/builtin_plugins/init/manager.py b/zhenxun/builtin_plugins/init/manager.py new file mode 100644 index 000000000..05b4779d2 --- /dev/null +++ b/zhenxun/builtin_plugins/init/manager.py @@ -0,0 +1,417 @@ +from copy import deepcopy + +from ruamel.yaml import YAML + +from zhenxun.services.log import logger +from zhenxun.configs.path_config import DATA_PATH +from zhenxun.models.plugin_info import PluginInfo +from zhenxun.models.plugin_limit import PluginLimit +from zhenxun.utils.enum import BlockType, LimitCheckType, PluginLimitType +from zhenxun.configs.utils import BaseBlock, PluginCdBlock, PluginCountBlock + +_yaml = YAML(pure=True) +_yaml.indent = 2 +_yaml.allow_unicode = True + + +CD_TEST = """需要cd的功能 +自定义的功能需要cd也可以在此配置 +key:模块名称 +cd:cd 时长(秒) +status:此限制的开关状态 +check_type:'PRIVATE'/'GROUP'/'ALL',限制私聊/群聊/全部 +watch_type:监听对象,以user_id或group_id作为键来限制,'USER':用户id,'GROUP':群id + 示例:'USER':用户N秒内触发1次,'GROUP':群N秒内触发1次 +result:回复的话,可以添加[at],[uname],[nickname]来对应艾特,用户群名称,昵称系统昵称 +result 为 "" 或 None 时则不回复 +result示例:"[uname]你冲的太快了,[nickname]先生,请稍后再冲[at]" +result回复:"老色批你冲的太快了,欧尼酱先生,请稍后再冲@老色批" + 用户昵称↑ 昵称系统的昵称↑ 艾特用户↑""" + + +BLOCK_TEST = """用户调用阻塞 +即 当用户调用此功能还未结束时 +用发送消息阻止用户重复调用此命令直到该命令结束 +key:模块名称 +status:此限制的开关状态 +check_type:'PRIVATE'/'GROUP'/'ALL',限制私聊/群聊/全部 +watch_type:监听对象,以user_id或group_id作为键来限制,'USER':用户id,'GROUP':群id + 示例:'USER':阻塞用户,'group':阻塞群聊 +result:回复的话,可以添加[at],[uname],[nickname]来对应艾特,用户群名称,昵称系统昵称 +result 为 "" 或 None 时则不回复 +result示例:"[uname]你冲的太快了,[nickname]先生,请稍后再冲[at]" +result回复:"老色批你冲的太快了,欧尼酱先生,请稍后再冲@老色批" + 用户昵称↑ 昵称系统的昵称↑ 艾特用户↑""" + +COUNT_TEST = """命令每日次数限制 +即 用户/群聊 每日可调用命令的次数 [数据内存存储,重启将会重置] +每日调用直到 00:00 刷新 +key:模块名称 +max_count: 每日调用上限 +status:此限制的开关状态 +watch_type:监听对象,以user_id或group_id作为键来限制,'USER':用户id,'GROUP':群id + 示例:'USER':用户上限,'group':群聊上限 +result:回复的话,可以添加[at],[uname],[nickname]来对应艾特,用户群名称,昵称系统昵称 +result 为 "" 或 None 时则不回复 +result示例:"[uname]你冲的太快了,[nickname]先生,请稍后再冲[at]" +result回复:"老色批你冲的太快了,欧尼酱先生,请稍后再冲@老色批" + 用户昵称↑ 昵称系统的昵称↑ 艾特用户↑""" + + +class Manager: + """ + 插件命令 cd 管理器 + """ + + def __init__(self): + self.cd_file = DATA_PATH / "configs" / "plugins2cd.yaml" + self.block_file = DATA_PATH / "configs" / "plugins2block.yaml" + self.count_file = DATA_PATH / "configs" / "plugins2count.yaml" + self.cd_data = {} + self.block_data = {} + self.count_data = {} + + def add( + self, + module_path: str, + data: BaseBlock | PluginCdBlock | PluginCountBlock | PluginLimit, + ): + """添加限制""" + if isinstance(data, PluginLimit): + check_type = BlockType.ALL + if LimitCheckType.GROUP == data.check_type: + check_type = BlockType.GROUP + elif LimitCheckType.PRIVATE == data.check_type: + check_type = BlockType.PRIVATE + if data.limit_type == PluginLimitType.CD: + data = PluginCdBlock( + status=data.status, + check_type=check_type, + watch_type=data.watch_type, + result=data.result, + cd=data.cd, + ) + elif data.limit_type == PluginLimitType.BLOCK: + data = BaseBlock( + status=data.status, + check_type=check_type, + watch_type=data.watch_type, + result=data.result, + ) + elif data.limit_type == PluginLimitType.COUNT: + data = PluginCountBlock( + status=data.status, + watch_type=data.watch_type, + result=data.result, + max_count=data.max_count, + ) + if isinstance(data, PluginCdBlock): + self.cd_data[module_path] = data + elif isinstance(data, PluginCountBlock): + self.count_data[module_path] = data + elif isinstance(data, BaseBlock): + self.block_data[module_path] = data + + def exist(self, module_path: str, type: PluginLimitType): + """是否存在""" + if type == PluginLimitType.CD: + return module_path in self.cd_data + elif type == PluginLimitType.BLOCK: + return module_path in self.block_data + elif type == PluginLimitType.COUNT: + return module_path in self.count_data + + def init(self): + if not self.cd_file.exists(): + self.save_cd_file() + if not self.block_file.exists(): + self.save_block_file() + if not self.count_file.exists(): + self.save_count_file() + self.__load_file() + + def __load_file(self): + self.__load_block_file() + self.__load_cd_file() + self.__load_count_file() + + def save_file(self): + """保存文件""" + self.save_cd_file() + self.save_block_file() + self.save_count_file() + + def save_cd_file(self): + """保存文件""" + self._extracted_from_save_file_3("PluginCdLimit", CD_TEST, self.cd_data) + + def save_block_file(self): + """保存文件""" + self._extracted_from_save_file_3( + "PluginBlockLimit", BLOCK_TEST, self.block_data + ) + + def save_count_file(self): + """保存文件""" + self._extracted_from_save_file_3( + "PluginCountLimit", COUNT_TEST, self.count_data + ) + + def _extracted_from_save_file_3(self, type_: str, after: str, data: dict): + """保存文件 + + 参数: + type_: 类型参数 + after: 备注 + """ + temp_data = deepcopy(data) + if not temp_data: + temp_data = { + "test": { + "status": False, + "check_type": "ALL", + "limit_type": "USER", + "result": "你冲的太快了,请稍后再冲", + } + } + if type_ == "PluginCdLimit": + temp_data["test"]["cd"] = 5 + elif type_ == "PluginCountLimit": + temp_data["test"]["max_count"] = 5 + del temp_data["test"]["check_type"] + else: + for v in temp_data: + temp_data[v] = temp_data[v].dict() + if check_type := temp_data[v].get("check_type"): + temp_data[v]["check_type"] = str(check_type) + if watch_type := temp_data[v].get("watch_type"): + temp_data[v]["watch_type"] = str(watch_type) + if type_ == "PluginCountLimit": + del temp_data[v]["check_type"] + file = self.block_file + if type_ == "PluginCdLimit": + file = self.cd_file + elif type_ == "PluginCountLimit": + file = self.count_file + with open(file, "w", encoding="utf8") as f: + _yaml.dump({type_: temp_data}, f) + with open(file, encoding="utf8") as rf: + _data = _yaml.load(rf) + _data.yaml_set_comment_before_after_key(after=after, key=type_) + with open(file, "w", encoding="utf8") as wf: + _yaml.dump(_data, wf) + + def __load_cd_file(self): + self.cd_data: dict[str, PluginCdBlock] = {} + if self.cd_file.exists(): + with open(self.cd_file, encoding="utf8") as f: + temp = _yaml.load(f) + if "PluginCdLimit" in temp.keys(): + for k, v in temp["PluginCdLimit"].items(): + self.cd_data[k] = PluginCdBlock.parse_obj(v) + + def __load_block_file(self): + self.block_data: dict[str, BaseBlock] = {} + if self.block_file.exists(): + with open(self.block_file, encoding="utf8") as f: + temp = _yaml.load(f) + if "PluginBlockLimit" in temp.keys(): + for k, v in temp["PluginBlockLimit"].items(): + self.block_data[k] = BaseBlock.parse_obj(v) + + def __load_count_file(self): + self.count_data: dict[str, PluginCountBlock] = {} + if self.count_file.exists(): + with open(self.count_file, encoding="utf8") as f: + temp = _yaml.load(f) + if "PluginCountLimit" in temp.keys(): + for k, v in temp["PluginCountLimit"].items(): + self.count_data[k] = PluginCountBlock.parse_obj(v) + + def __replace_data( + self, + db_data: PluginLimit | None, + limit: PluginCdBlock | BaseBlock | PluginCountBlock, + ) -> PluginLimit: + """替换数据""" + if not db_data: + db_data = PluginLimit() + db_data.status = limit.status + check_type = LimitCheckType.ALL + if BlockType.GROUP == limit.check_type: + check_type = LimitCheckType.GROUP + elif BlockType.PRIVATE == limit.check_type: + check_type = LimitCheckType.PRIVATE + db_data.check_type = check_type + db_data.watch_type = limit.watch_type + db_data.result = limit.result or "" + return db_data + + def __set_data( + self, + k: str, + db_data: PluginLimit | None, + limit: PluginCdBlock | BaseBlock | PluginCountBlock, + limit_type: PluginLimitType, + module2plugin: dict[str, PluginInfo], + ) -> tuple[PluginLimit, bool]: + """设置数据 + + 参数: + k: 模块名 + db_data: 数据库数据 + limit: 文件数据 + limit_type: 限制类型 + module2plugin: 模块:插件信息 + + 返回: + tuple[PluginLimit, bool]: PluginLimit,是否创建 + """ + if not db_data: + return ( + PluginLimit( + module=k.split(".")[-1], + module_path=k, + limit_type=limit_type, + plugin=module2plugin.get(k), + cd=getattr(limit, "cd", None), + max_count=getattr(limit, "max_count", None), + status=limit.status, + check_type=limit.check_type, + watch_type=limit.watch_type, + result=limit.result, + ), + True, + ) + db_data = self.__replace_data(db_data, limit) + if limit_type == PluginLimitType.CD: + db_data.cd = limit.cd # type: ignore + if limit_type == PluginLimitType.COUNT: + db_data.max_count = limit.max_count # type: ignore + return db_data, False + + def __get_file_data(self, limit_type: PluginLimitType) -> dict: + """获取文件数据 + + 参数: + limit_type: 限制类型 + + 返回: + dict: 文件数据 + """ + if limit_type == PluginLimitType.CD: + return self.cd_data + elif limit_type == PluginLimitType.COUNT: + return self.count_data + else: + return self.block_data + + def __set_db_limits( + self, + db_limits: list[PluginLimit], + module2plugin: dict[str, PluginInfo], + limit_type: PluginLimitType, + ) -> tuple[list[PluginLimit], list[PluginLimit], list[int]]: + """更新cd限制数据 + + 参数: + db_limits: 数据库limits + module2plugin: 模块:插件信息 + + 返回: + tuple[list[PluginLimit], list[PluginLimit]]: 创建列表,更新列表 + """ + update_list = [] + create_list = [] + delete_list = [] + db_type_limits = [ + limit for limit in db_limits if limit.limit_type == limit_type + ] + if data := self.__get_file_data(limit_type): + db_type_limit_modules = [ + (limit.module_path, limit.id) for limit in db_type_limits + ] + delete_list.extend( + id + for module_path, id in db_type_limit_modules + if module_path not in data.keys() + ) + for k, v in data.items(): + if not module2plugin.get(k): + if k != "test": + logger.warning( + f"插件模块 {k} 未加载,已过滤当前 {v._type} 限制..." + ) + continue + db_data = [limit for limit in db_type_limits if limit.module_path == k] + db_data, is_create = self.__set_data( + k, db_data[0] if db_data else None, v, limit_type, module2plugin + ) + if is_create: + create_list.append(db_data) + else: + update_list.append(db_data) + else: + delete_list = [limit.id for limit in db_type_limits] + return create_list, update_list, delete_list + + async def __set_all_limit( + self, + ) -> tuple[list[PluginLimit], list[PluginLimit], list[int]]: + """获取所有插件限制数据 + + 返回: + tuple[list[PluginLimit], list[PluginLimit]]: 创建列表,更新列表 + """ + db_limits = await PluginLimit.all() + modules = set( + list(self.cd_data.keys()) + + list(self.block_data.keys()) + + list(self.count_data.keys()) + ) + plugins = await PluginInfo.get_plugins(module_path__in=modules) + module2plugin = {p.module_path: p for p in plugins} + create_list, update_list, delete_list = self.__set_db_limits( + db_limits, module2plugin, PluginLimitType.CD + ) + create_list1, update_list1, delete_list1 = self.__set_db_limits( + db_limits, module2plugin, PluginLimitType.COUNT + ) + create_list2, update_list2, delete_list2 = self.__set_db_limits( + db_limits, module2plugin, PluginLimitType.BLOCK + ) + all_create = create_list + create_list1 + create_list2 + all_update = update_list + update_list1 + update_list2 + all_delete = delete_list + delete_list1 + delete_list2 + return all_create, all_update, all_delete + + async def load_to_db(self): + """读取配置文件""" + + create_list, update_list, delete_list = await self.__set_all_limit() + if create_list: + await PluginLimit.bulk_create(create_list) + if update_list: + for limit in update_list: + await limit.save( + update_fields=[ + "status", + "check_type", + "watch_type", + "result", + "cd", + "max_count", + ] + ) + # TODO: tortoise.exceptions.OperationalError:syntax error at or near "GROUP" + # await PluginLimit.bulk_update( + # update_list, + # ["status", "check_type", "watch_type", "result", "cd", "max_count"], + # ) + if delete_list: + await PluginLimit.filter(id__in=delete_list).delete() + cnt = await PluginLimit.filter(status=True).count() + logger.info(f"已经加载 {cnt} 个插件限制.") + + +manager = Manager() diff --git a/zhenxun/configs/utils/__init__.py b/zhenxun/configs/utils/__init__.py index 69d334a14..b041abe30 100644 --- a/zhenxun/configs/utils/__init__.py +++ b/zhenxun/configs/utils/__init__.py @@ -408,13 +408,6 @@ def save(self, path: str | Path | None = None, save_simple_data: bool = False): """ if save_simple_data: with open(self._simple_file, "w", encoding="utf8") as f: - # yaml.dump( - # self._simple_data, - # f, - # indent=2, - # Dumper=yaml.RoundTripDumper, - # allow_unicode=True, - # ) _yaml.dump(self._simple_data, f) path = path or self.file data = {} @@ -426,14 +419,10 @@ def save(self, path: str | Path | None = None, save_simple_data: bool = False): del value["arg_parser"] data[module][config] = value with open(path, "w", encoding="utf8") as f: - # yaml.dump( - # data, f, indent=2, Dumper=yaml.RoundTripDumper, allow_unicode=True - # ) _yaml.dump(data, f) def reload(self): """重新加载配置文件""" - _yaml = YAML() if self._simple_file.exists(): with open(self._simple_file, encoding="utf8") as f: self._simple_data = _yaml.load(f) diff --git a/zhenxun/models/plugin_info.py b/zhenxun/models/plugin_info.py index 964dabad2..e07cb5a9b 100644 --- a/zhenxun/models/plugin_info.py +++ b/zhenxun/models/plugin_info.py @@ -1,3 +1,5 @@ +from typing_extensions import Self + from tortoise import fields from zhenxun.services.db_context import Model @@ -51,6 +53,30 @@ class Meta: # type: ignore table = "plugin_info" table_description = "插件基本信息" + @classmethod + async def get_plugin(cls, load_status: bool = True, **kwargs) -> Self | None: + """获取插件列表 + + 参数: + load_status: 加载状态. + + 返回: + Self | None: 插件 + """ + return await cls.get_or_none(load_status=load_status, **kwargs) + + @classmethod + async def get_plugins(cls, load_status: bool = True, **kwargs) -> list[Self]: + """获取插件列表 + + 参数: + load_status: 加载状态. + + 返回: + list[Self]: 插件列表 + """ + return await cls.filter(load_status=load_status, **kwargs).all() + @classmethod async def _run_script(cls): return [ diff --git a/zhenxun/models/plugin_limit.py b/zhenxun/models/plugin_limit.py index e6b185e73..172c53946 100644 --- a/zhenxun/models/plugin_limit.py +++ b/zhenxun/models/plugin_limit.py @@ -35,6 +35,6 @@ class PluginLimit(Model): max_count = fields.IntField(null=True, description="最大调用次数") """最大调用次数""" - class Meta: + class Meta: # type: ignore table = "plugin_limit" table_description = "插件限制"