This repository was archived by the owner on Aug 16, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathbot.py
572 lines (488 loc) · 19 KB
/
bot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
# cython: language_level=3
# Copyright (c) 2021-present Pycord Development
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
import asyncio
from typing import Any, AsyncGenerator, Type, TypeVar
from aiohttp import BasicAuth
from .application_role_connection_metadata import ApplicationRoleConnectionMetadata
from .audit_log import AuditLog
from .commands import Group
from .commands.application.command import ApplicationCommand
from .enums import (
DefaultMessageNotificationLevel,
ExplicitContentFilterLevel,
VerificationLevel,
)
from .errors import BotException, NoIdentifiesLeft, OverfilledShardsException
from .events.event_manager import Event
from .file import File
from .flags import Intents, SystemChannelFlags
from .gateway import PassThrough, ShardCluster, ShardManager
from .guild import Guild, GuildPreview
from .interface import print_banner, start_logging
from .missing import MISSING, Maybe, MissingEnum
from .snowflake import Snowflake
from .state import State
from .types import AsyncFunc
from .types.audit_log import AUDIT_LOG_EVENT_TYPE
from .user import User
from .utils import chunk, get_arg_defaults
# Generic placeholder type used by the decorator helpers on ``Bot``
# (``listen``, ``command``, ``group``, ``wait_for``).
T = TypeVar("T")
class Bot:
    """
    The class for interacting with Discord with a Bot.

    Parameters
    ----------
    intents: :class:`.flags.Intents`
        The Gateway Intents to use.
    print_banner_on_startup: :class:`bool`
        Whether to print the banner on startup or not.
        Defaults to `True`.
    logging_flavor: :class:`int` | :class:`str` | dict[:class:`str`, :class:`typing.Any`] | None
        The logging flavor this bot uses.
        Defaults to `None`.
    max_messages: :class:`int`
        The maximum amount of Messages to cache.
        Defaults to 1000.
    shards: :class:`int` | list[:class:`int`] | None
        The amount of shards this bot should launch with, or an explicit
        list of shard IDs.
        Defaults to `None`, which uses the shard count reported by the gateway.
    global_shard_status: :class:`int` | None
        The amount of shards globally deployed.
        Only supported on bots not using `.cluster`.
        When omitted, it is derived from ``shards``.
    proxy: :class:`str` | None
        The proxy to use.
        Defaults to `None`.
    proxy_auth: :class:`aiohttp.BasicAuth` | None
        The authentication of your proxy.
        Defaults to `None`.
    verbose: :class:`bool`
        Forwarded to the internal :class:`.state.State`.
        Defaults to `False`.

    Attributes
    ----------
    user: :class:`.user.User`
        The user within this bot.
    guilds: list[:class:`.guild.Guild`]
        The Guilds this Bot is in (an awaitable property: ``await bot.guilds``).
    """

    def __init__(
        self,
        intents: Intents,
        print_banner_on_startup: bool = True,
        logging_flavor: int | str | dict[str, Any] | None = None,
        max_messages: int = 1000,
        shards: int | list[int] | None = None,
        global_shard_status: int | None = None,
        proxy: str | None = None,
        proxy_auth: BasicAuth | None = None,
        verbose: bool = False,
    ) -> None:
        self.intents: Intents = intents
        self.max_messages: int = max_messages
        self._state: State = State(
            intents=self.intents, max_messages=self.max_messages, verbose=verbose
        )
        self._shards = shards
        self._logging_flavor: int | str | dict[str, Any] | None = logging_flavor
        self._print_banner = print_banner_on_startup
        self._proxy = proxy
        self._proxy_auth = proxy_auth

        # An explicit global shard count wins; otherwise derive it from
        # ``shards`` (a list of shard IDs counts its length, an int is the
        # count itself). ``None`` means "decide at startup" (see _run_async).
        if shards and not global_shard_status:
            if isinstance(shards, list):
                self._global_shard_status = len(shards)
            else:
                self._global_shard_status = int(shards)
        elif global_shard_status:
            self._global_shard_status = global_shard_status
        else:
            self._global_shard_status = None

    @property
    def user(self) -> User:
        """The bot's own user, as held by the internal state."""
        return self._state.user

    def _resolve_shard_ids(self, info: Any) -> list[int]:
        """Return the shard IDs this process should run.

        ``self._shards`` wins when set: a list is used as-is and an int ``n``
        becomes ``range(n)``. Otherwise the ``shards`` count from the gateway
        response ``info`` is used.
        """
        if self._shards is None:
            return list(range(info['shards']))
        if isinstance(self._shards, list):
            return self._shards
        return list(range(self._shards))

    @staticmethod
    def _check_session_start_limit(session_start_limit: Any, shard_count: int) -> None:
        """Raise :exc:`NoIdentifiesLeft` if ``shard_count`` shards cannot be started.

        Raises
        ------
        NoIdentifiesLeft
            If no session starts remain, or if starting ``shard_count`` shards
            would leave zero (or fewer) remaining.
        """
        remaining = session_start_limit['remaining']
        if remaining == 0:
            raise NoIdentifiesLeft('session_start_limit has been exhausted')
        # NOTE(review): ``<= 0`` also rejects the case where exactly
        # ``shard_count`` starts remain; kept as-is — confirm whether this
        # headroom is intentional.
        if remaining - shard_count <= 0:
            raise NoIdentifiesLeft('session_start_limit will be exhausted')

    @staticmethod
    def _resolve_cluster_layout(
        shard_count: int,
        clusters: int,
        amount: int | None,
        managers: int | None,
    ) -> tuple[int, int]:
        """Validate a clustered layout and fill in defaults.

        Parameters
        ----------
        shard_count: :class:`int`
            Number of shards this instance will run.
        clusters: :class:`int`
            Requested number of clusters.
        amount: :class:`int` | None
            Global shard amount; falls back to ``shard_count`` when falsy.
        managers: :class:`int` | None
            Managers per cluster; falls back to 1 when falsy.

        Returns
        -------
        tuple[:class:`int`, :class:`int`]
            The resolved ``(amount, managers)``.

        Raises
        ------
        OverfilledShardsException
            If the layout asks for more clusters/managers than shards, or a
            global amount smaller than the local shard count.
        """
        if clusters > shard_count:
            raise OverfilledShardsException('Cannot have more clusters than shards')
        if not amount:
            amount = shard_count
        if not managers:
            managers = 1
        if amount < shard_count:
            raise OverfilledShardsException(
                'Cannot have a higher shard count than shard amount'
            )
        if managers > shard_count:
            raise OverfilledShardsException('Cannot have more managers than shards')
        return amount, managers

    async def _wait_for_user(self) -> None:
        """Wait until ``self._state.raw_user`` is populated.

        A fresh future is parked on the state each round; presumably the state
        resolves it once the user payload arrives — confirm in ``State``.
        """
        while not self._state.raw_user:
            self._state._raw_user_fut: asyncio.Future[None] = asyncio.Future()
            await self._state._raw_user_fut

    async def _run_async(self, token: str) -> None:
        """Start all shards in a single :class:`ShardManager` and block forever."""
        start_logging(flavor=self._logging_flavor)
        self._state.bot_init(
            token=token, clustered=False, proxy=self._proxy, proxy_auth=self._proxy_auth
        )

        info = await self._state.http.get_gateway_bot()
        session_start_limit = info['session_start_limit']
        self._state.shard_concurrency = PassThrough(
            session_start_limit['max_concurrency'], 7
        )
        self._state._session_start_limit = session_start_limit

        shards = self._resolve_shard_ids(info)
        self._check_session_start_limit(session_start_limit, len(shards))

        sharder = ShardManager(
            self._state,
            shards,
            self._global_shard_status or len(shards),
            proxy=self._proxy,
            proxy_auth=self._proxy_auth,
        )
        await sharder.start()
        self._state.shard_managers.append(sharder)

        await self._wait_for_user()

        if self._print_banner:
            # ``len(shards)`` covers every form of ``self._shards``: an int n
            # became range(n), a list is used as-is, None came from the gateway.
            print_banner(
                self._state._session_start_limit['remaining'],
                len(shards),
                bot_name=self.user.name,
            )

        await self._run_until_exited()

    async def _run_until_exited(self) -> None:
        """Park until cancelled, then close HTTP/shard sessions and release clusters."""
        try:
            await asyncio.Future()
        except (asyncio.CancelledError, KeyboardInterrupt):
            # most things are already handled by the asyncio.run function
            # the only thing we have to worry about are aiohttp errors
            await self._state.http.close_session()
            for sm in self._state.shard_managers:
                await sm.session.close()
            if self._state._clustered:
                for sc in self._state.shard_clusters:
                    # Resolving an already-done future raises InvalidStateError.
                    if not sc.keep_alive.done():
                        sc.keep_alive.set_result(None)

    def run(self, token: str) -> None:
        """
        Run the Bot without being clustered.

        .. WARNING::
            This blocks permanently and doesn't allow functions after it to run

        Parameters
        ----------
        token: :class:`str`
            The authentication token of this Bot.
        """
        asyncio.run(self._run_async(token=token))

    async def _run_cluster(
        self,
        token: str,
        clusters: int,
        amount: int | None,
        managers: int | None,
    ) -> None:
        """Start the shards split across ``clusters`` :class:`ShardCluster` objects.

        ``amount``/``managers`` may still be unresolved (``None``) when no shard
        count was configured; they are validated here once the gateway reports it.
        """
        start_logging(flavor=self._logging_flavor)
        self._state.bot_init(
            token=token, clustered=True, proxy=self._proxy, proxy_auth=self._proxy_auth
        )

        info = await self._state.http.get_gateway_bot()
        session_start_limit = info['session_start_limit']

        shards = self._resolve_shard_ids(info)
        if self._shards is None:
            # cluster() could not validate the layout without a shard count.
            amount, managers = self._resolve_cluster_layout(
                len(shards), clusters, amount, managers
            )
        self._check_session_start_limit(session_start_limit, len(shards))

        self._state.shard_concurrency = PassThrough(
            session_start_limit['max_concurrency'], 7
        )
        self._state._session_start_limit = session_start_limit

        for cluster_shards in list(chunk(shards, clusters)):
            cluster_class = ShardCluster(
                self._state,
                cluster_shards,
                amount,
                managers,
                proxy=self._proxy,
                proxy_auth=self._proxy_auth,
            )
            cluster_class.run()
            self._state.shard_clusters.append(cluster_class)

        await self._wait_for_user()

        if self._print_banner:
            print_banner(
                concurrency=self._state._session_start_limit['remaining'],
                shard_count=len(shards),
                bot_name=self._state.user.name,
            )

        await self._run_until_exited()

    def cluster(
        self,
        token: str,
        clusters: int,
        amount: int | None = None,
        managers: int | None = None,
    ) -> None:
        """
        Run the Bot in a clustered formation.

        Much more complex but much more scalable.

        .. WARNING:: Shouldn't be used on Bots under ~300k Guilds

        Parameters
        ----------
        token: :class:`str`
            The authentication token of this Bot.
        clusters: :class:`int`
            The amount of clusters to run.
        amount: :class:`int` | None
            The full amount of shards that are/will be running globally (not just on this instance.)
            Defaults to the number of shards run by this instance.
        managers: :class:`int` | None
            The amount of managers to hold per cluster.
            Defaults to `None`, which uses one manager per cluster.

        Raises
        ------
        OverfilledShardsException
            If the requested layout does not fit the shard count.
        """
        # With an explicit shard configuration the layout can be validated
        # before anything starts; otherwise _run_cluster validates it once the
        # gateway has reported the shard count.
        if self._shards is not None:
            shard_count = (
                self._shards if isinstance(self._shards, int) else len(self._shards)
            )
            amount, managers = self._resolve_cluster_layout(
                shard_count, clusters, amount, managers
            )
        asyncio.run(
            self._run_cluster(
                token=token, clusters=clusters, amount=amount, managers=managers
            )
        )

    def listen(self, event: Event | None = None) -> T:
        """
        Listen to an event

        Parameters
        ----------
        event: :class:`Event` | None
            The event to listen to.
            Optional if using type hints.

        Raises
        ------
        BotException
            If the decorated function does not take exactly one argument, or
            its argument is not annotated with an :class:`Event` subclass and
            no ``event`` was given.
        """

        def wrapper(func: T) -> T:
            if event is not None:
                self._state.event_manager.add_event(event, func)
                return func

            args = get_arg_defaults(func)
            values = list(args.values())
            if len(values) != 1:
                raise BotException(
                    'Only one argument is allowed on event functions'
                )
            # presumably each value is (default, annotation) — confirm in utils
            event_type = values[0][1]
            if event_type is None:
                raise BotException(
                    'Event must either be typed, or be present in the `event` parameter'
                )
            # Check the class itself rather than instantiating it.
            if not (isinstance(event_type, type) and issubclass(event_type, Event)):
                raise BotException('Events must be of type Event')
            self._state.event_manager.add_event(event_type, func)
            return func

        return wrapper

    def wait_for(self, event: T) -> asyncio.Future[T]:
        """Return the event manager's future for ``event``.

        Parameters
        ----------
        event:
            The event type to wait for.
        """
        return self._state.event_manager.wait_for(event)

    def command(
        self,
        name: str | MissingEnum = MISSING,
        cls: T = ApplicationCommand,
        **kwargs: Any,
    ) -> T:
        """
        Create a command within the Bot

        Parameters
        ----------
        name: :class:`str`
            The name of the Command.
        cls: type of :class:`.commands.Command`
            The command type to instantiate.
        kwargs: dict[str, Any]
            The kwargs to entail onto the instantiated command.
        """

        def wrapper(func: AsyncFunc) -> T:
            command = cls(func, name=name, state=self._state, **kwargs)
            self._state.commands.append(command)
            return command

        return wrapper

    def group(self, name: str, cls: Type[Group] = Group, **kwargs: Any) -> T:
        """
        Create a brand-new Group of Commands

        Parameters
        ----------
        name: :class:`str`
            The name of the Group.
        cls: type of :class:`.commands.Group`
            The group type to instantiate.
            Defaults to :class:`.commands.Group`.
        kwargs: dict[str, Any]
            The kwargs to entail onto the instantiated group.
        """

        # NOTE(review): unlike ``command``, the group is not appended to
        # ``self._state.commands`` here — confirm Group registers itself.
        def wrapper(func: T) -> T:
            return cls(func, name, state=self._state, **kwargs)

        return wrapper

    @property
    async def guilds(self) -> AsyncGenerator[Guild, None]:
        """The guilds in the store; must be awaited (``await bot.guilds``)."""
        # NOTE(review): returns whatever ``get_all()`` yields after awaiting;
        # the AsyncGenerator annotation may not match — confirm in the store.
        return await (self._state.store.sift('guilds')).get_all()

    async def get_application_role_connection_metadata_records(
        self,
    ) -> list[ApplicationRoleConnectionMetadata]:
        """Get the application role connection metadata records.

        Returns
        -------
        list[:class:`ApplicationRoleConnectionMetadata`]
            The application role connection metadata records.
        """
        # NOTE(review): the bot user's ID is used where the endpoint presumably
        # expects the application ID; these usually coincide — confirm.
        data = await self._state.http.get_application_role_connection_metadata_records(
            self.user.id
        )
        return [ApplicationRoleConnectionMetadata.from_dict(record) for record in data]

    async def update_application_role_connection_metadata_records(
        self, records: list[ApplicationRoleConnectionMetadata]
    ) -> list[ApplicationRoleConnectionMetadata]:
        """Update the application role connection metadata records.

        Parameters
        ----------
        records: list[:class:`ApplicationRoleConnectionMetadata`]
            The application role connection metadata records.

        Returns
        -------
        list[:class:`ApplicationRoleConnectionMetadata`]
            The updated application role connection metadata records.
        """
        data = (
            await self._state.http.update_application_role_connection_metadata_records(
                self.user.id, [record.to_dict() for record in records]
            )
        )
        return [ApplicationRoleConnectionMetadata.from_dict(record) for record in data]

    async def create_guild(
        self,
        name: str,
        *,
        icon: File | MissingEnum = MISSING,
        verification_level: VerificationLevel | MissingEnum = MISSING,
        default_message_notifications: DefaultMessageNotificationLevel
        | MissingEnum = MISSING,
        explicit_content_filter: ExplicitContentFilterLevel | MissingEnum = MISSING,
        roles: list[dict] | MissingEnum = MISSING,  # TODO
        channels: list[dict] | MissingEnum = MISSING,  # TODO
        afk_channel_id: Snowflake | MissingEnum = MISSING,
        afk_timeout: int | MissingEnum = MISSING,
        system_channel_id: Snowflake | MissingEnum = MISSING,
        system_channel_flags: SystemChannelFlags | MissingEnum = MISSING,
    ) -> Guild:
        """Create a guild.

        Parameters
        ----------
        name: :class:`str`
            The name of the guild.
        icon: :class:`.File`
            The icon of the guild.
        verification_level: :class:`VerificationLevel`
            The verification level of the guild.
        default_message_notifications: :class:`DefaultMessageNotificationLevel`
            The default message notifications of the guild.
        explicit_content_filter: :class:`ExplicitContentFilterLevel`
            The explicit content filter level of the guild.
        roles: list[dict]
            The roles of the guild.
        channels: list[dict]
            The channels of the guild.
        afk_channel_id: :class:`Snowflake`
            The afk channel id of the guild.
        afk_timeout: :class:`int`
            The afk timeout of the guild.
        system_channel_id: :class:`Snowflake`
            The system channel id of the guild.
        system_channel_flags: :class:`SystemChannelFlags`
            The system channel flags of the guild.

        Returns
        -------
        :class:`Guild`
            The created guild.
        """
        data = await self._state.http.create_guild(
            name,
            icon=icon,
            verification_level=verification_level,
            default_message_notifications=default_message_notifications,
            explicit_content_filter=explicit_content_filter,
            roles=roles,
            channels=channels,
            afk_channel_id=afk_channel_id,
            afk_timeout=afk_timeout,
            system_channel_id=system_channel_id,
            system_channel_flags=system_channel_flags,
        )
        return Guild(data, self._state)

    async def get_guild(self, guild_id: Snowflake) -> Guild:
        """Get a guild.

        Parameters
        ----------
        guild_id: :class:`Snowflake`
            The guild id.

        Returns
        -------
        :class:`Guild`
            The guild.
        """
        data = await self._state.http.get_guild(guild_id)
        return Guild(data, self._state)

    async def get_guild_preview(self, guild_id: Snowflake) -> GuildPreview:
        """Get a guild preview.

        Parameters
        ----------
        guild_id: :class:`Snowflake`
            The guild id.

        Returns
        -------
        :class:`GuildPreview`
            The guild preview.
        """
        data = await self._state.http.get_guild_preview(guild_id)
        return GuildPreview(data, self._state)

    async def fetch_guild_audit_log(
        self,
        guild_id: Snowflake,
        user_id: Snowflake | MissingEnum = MISSING,
        action_type: AUDIT_LOG_EVENT_TYPE | MissingEnum = MISSING,
        before: Snowflake | MissingEnum = MISSING,
        after: Snowflake | MissingEnum = MISSING,
        limit: int | MissingEnum = MISSING,
    ) -> AuditLog:
        """
        Fetches and returns the audit log.

        Parameters
        ----------
        guild_id: :class:`Snowflake`
            The guild whose audit log to fetch.
        user_id: :class:`Snowflake`
            Optional filter forwarded to the HTTP call.
        action_type: AUDIT_LOG_EVENT_TYPE
            Optional filter forwarded to the HTTP call.
        before: :class:`Snowflake`
            Optional bound forwarded to the HTTP call.
        after: :class:`Snowflake`
            Optional bound forwarded to the HTTP call.
        limit: :class:`int`
            Optional limit forwarded to the HTTP call.

        Returns
        -------
        :class:`.AuditLog`
        """
        raw_audit_log = await self._state.http.get_guild_audit_log(
            guild_id=guild_id,
            user_id=user_id,
            action_type=action_type,
            before=before,
            after=after,
            limit=limit,
        )
        return AuditLog(raw_audit_log, self._state)