about | summary | bugs & patches | refs | log | tree | commit | diff | stats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r-- src/discord_image_bridge/_values.py 43
-rw-r--r-- src/discord_image_bridge/discord.py 324
-rw-r--r-- src/discord_image_bridge/sendqueue.py 124
3 files changed, 491 insertions, 0 deletions
diff --git a/src/discord_image_bridge/_values.py b/src/discord_image_bridge/_values.py
new file mode 100644
index 0000000..2b2ed09
--- /dev/null
+++ b/src/discord_image_bridge/_values.py
@@ -0,0 +1,43 @@
+
+import os
+
+# Names exported by ``from ._values import *``; all of them are assigned
+# further down in this module.
+__all__ = [
+ "API_VERSION",
+ "BASEURL",
+ "TOKEN",
+ "HOMEPAGE",
+ "VERSION",
+ "USERAGENT",
+ ]
+
+
+# Discord REST API version and the base URL derived from it (ends with "/").
+API_VERSION = 10
+BASEURL = f"https://discord.com/api/v{API_VERSION}/"
+
+# Bot token taken from the TOKEN environment variable; None when it is unset.
+TOKEN = os.environ.get("TOKEN")
+
+from importlib.metadata import distribution as _distribution
+
+_dist=_distribution(__package__)
+
+# Try to determine the package homepage from the "Project-URL" metadata
+# entries, each of which has the form "<label>, <url>".
+#  * get_all() returns None when no such entry exists, hence the `or ()`.
+#  * partition(",") plus strip() accepts any spacing around the separator and
+#    keeps commas that are part of the URL.
+#  * the comprehension keeps its loop variables out of the module namespace,
+#    so the clean-up below cannot fail on names that were never bound.
+_urls={
+    _label.strip(): _link.strip()
+    for _label, _, _link in (
+        _entry.partition(",")
+        for _entry in _dist.metadata.get_all("Project-URL") or ()
+    )
+}
+
+for _possible in ["Homepage","Repository","Documentation"]:
+    if _possible in _urls:
+        HOMEPAGE=_urls[_possible]
+        break
+else:
+    # None of the preferred labels is declared: fall back to the first URL of
+    # any kind, or to an empty string when the package declares no URL at all.
+    # (dict views cannot be indexed, so take the first item via an iterator.)
+    HOMEPAGE=next(iter(_urls.values()), "")
+
+del _possible, _urls, _distribution # clean up
+
+VERSION=_dist.version
+
+del _dist
+
+# User-Agent value in the "DiscordBot (<url>, <version>)" form.
+USERAGENT="DiscordBot (%s, %s)"%(HOMEPAGE, VERSION)
+
diff --git a/src/discord_image_bridge/discord.py b/src/discord_image_bridge/discord.py
new file mode 100644
index 0000000..1257bc4
--- /dev/null
+++ b/src/discord_image_bridge/discord.py
@@ -0,0 +1,324 @@
+"""
+Helper functions and a websocket plugin for interacting with discord
+"""
+
+from threading import Event, Thread
+from time import monotonic, sleep
+
+from json import loads, dumps, JSONDecodeError
+
+from ws4py.client.threadedclient import WebSocketClient
+from ws4py.messaging import TextMessage
+
+import cherrypy
+
+from requests import Request
+
+from . import sendqueue as rl
+from .dbpool import DBPoolManager
+
+from . import _values
+# Module-level alias of the bot token; it is sent in the op 2 and op 6 gateway
+# payloads built in DiscordWsClient.received_message.
+TOKEN=_values.TOKEN
+
+def log(*msg):
+    """Write *msg* to the CherryPy log under the "DISCORD" context.
+
+    Every argument is converted with ``str()`` and the pieces are joined with
+    single spaces.
+    """
+    text=" ".join(map(str, msg))
+    cherrypy.log(text,context="DISCORD")
+
+class intents:
+    """Gateway intent bit flags, as sent in the "intents" field of the op 2 payload."""
+    GUILD_MESSAGES = 1 << 9 # bit 9
+
+class DiscordWsClientMandatoryAttrs():
+    """Attributes every gateway client object must provide.
+
+    DiscordWsClient copies its starting state (sequence number, session id,
+    resumability and handler registry) from the client that preceded it. An
+    instance of this class serves as that predecessor before any real
+    connection exists.
+    """
+    def __init__(self, previous_client=None):
+        # set once the connection has been closed
+        self.closed_event=Event()
+
+        self.can_resume:bool=True
+        self.resume_url:str=None # present so every client object has the attribute
+        self.session_id:str=None
+        self.last_s:int=0
+
+        # event name -> list of handler functions
+        self.registered_functions={}
+
+        if not previous_client:
+            # nothing came before this object: it counts as already closed and
+            # there is no session that could be resumed
+            self.closed_event.set()
+            self.can_resume=False
+
+    def close(self, code=1000, reason=""):
+        """Do nothing: this object has no socket of its own.
+
+        The optional *code* and *reason* parameters let it be called the same
+        way as a real client's close(), e.g. ``close(4000)``.
+        """
+        pass
+
+class DiscordWsClient(WebSocketClient, DiscordWsClientMandatoryAttrs):
+ def __init__(self, previous_client, *args, **kwargs):
+     """Create a gateway client that continues from *previous_client*.
+
+     All remaining positional and keyword arguments are passed on to
+     WebSocketClient.
+     """
+     WebSocketClient.__init__(self, *args, **kwargs)
+     DiscordWsClientMandatoryAttrs.__init__(self, previous_client)
+
+     # state that starts fresh for every connection
+     self.resume_url:str=None
+     self.last_heartbeat_ack=0
+
+     # state taken over from the previous client
+     self.resuming:bool=previous_client.can_resume
+     self.session_id:str=previous_client.session_id
+     self.last_s:int=previous_client.last_s
+
+     # the handler registry is the same dict object as the previous client's, not a copy
+     self.registered_functions:dict[str, list[callable]]=previous_client.registered_functions
+
+ def send(self, payload):
+     """Serialise *payload* to JSON and send it through the parent class's send()."""
+     encoded=dumps(payload)
+     super().send(encoded)
+
+ def opened(self):
+     """Log that the gateway connection was opened."""
+     # TODO
+     log("Gateway opened")
+
+ def closed(self, code, reason=None):
+     """Record whether the session may be resumed, then signal the closure.
+
+     can_resume becomes True only for the close codes listed below; the
+     result is logged and closed_event is set afterwards.
+     """
+     resumable_codes=(4000, 4001, 4002, 4003, 4005, 4006, 4007, 4008, 4009)
+     self.can_resume=code in resumable_codes
+     log("Gateway closed. resumable:",self.can_resume)
+     self.closed_event.set()
+
+ def heartbeat(self, interval:int):
+ # Heartbeat loop (start_heartbeat() runs it in its own thread).
+ # Every `interval` milliseconds, until closed_event is set, it checks that
+ # the previous heartbeat was acknowledged and then sends the next one.
+ last_heartbeat=0
+ skipped_beats=0
+ # Event.wait() returns False on timeout, so the body runs once per interval
+ while not self.closed_event.wait(interval/1000): # milliseconds
+ # check if connection is still alive
+ # (last_heartbeat_ack is set with monotonic() when an op 11 message arrives)
+ if self.last_heartbeat_ack<last_heartbeat: # if no heartbeat was ack'ed after last one
+ log(
+ "Last gateway heartbeat ack was before last heartbeat. Last ack: %d, last hb: %d"%(
+ self.last_heartbeat_ack, last_heartbeat)
+ )
+ skipped_beats+=1 # we track the amount of beats skipped, to account for unfortunate timing in the scheduler
+ if skipped_beats>=2:
+ log("2 heartbeats skipped, so probably dead. Closing connection and resuming.")
+ # NOTE(review): closed() assigns can_resume again from the close code it
+ # receives, so this value may be overwritten - confirm which code arrives
+ self.can_resume=True
+ self.close(4000) # any code except 1000/1001 is fine
+ else:
+ skipped_beats=0
+
+ # send heartbeat
+ # NOTE(review): the loop only stops once closed_event is set, so this send
+ # presumably still runs right after close(4000) above - confirm it is harmless
+ #log("Heartbeat")
+ self.send({"op":1, "d":self.last_s})
+ last_heartbeat=monotonic()
+
+ def start_heartbeat(self, interval:int):
+     """Run ``heartbeat(interval)`` in a newly started thread."""
+     Thread(target=self.heartbeat, args=(interval,)).start()
+
+ def received_message(self, m:TextMessage):
+ # Handle one message received from the gateway: decode it as JSON and act on
+ # its opcode. Undecodable messages are logged and ignored; any exception
+ # raised while processing is logged with its traceback instead of propagating.
+ try:
+ data=loads(m.data.decode(m.encoding))
+ except (UnicodeDecodeError, JSONDecodeError) as e:
+ log("Error loading message from gateway: ",e,"\nExact message:",repr(m.data))
+ log("Ignoring event.")
+ return
+
+ try: # one big try-catch block... not amazing
+ match data["op"]: # match opcode
+ case 0: # Dispatch
+ log("Dispatch (op 0):",data["t"])
+ # remember the sequence number; it is sent back in heartbeats and on resume
+ self.last_s=data["s"]
+ # switch some vars around for ease of usage
+ # (from here on `payload` is the whole message and `data` is its "d" field)
+ payload=data
+ data=payload["d"]
+
+ # match/case the event type and respond accordingly
+ match payload["t"]:
+ # connection-related events
+ case "READY":
+ # store what a later resume needs, then announce readiness on the cherrypy bus
+ self.resume_url=data["resume_gateway_url"]
+ self.session_id=data["session_id"]
+ cherrypy.engine.publish("discord-gateway-ready", client=self, data=data, payload=payload)
+
+ case "RESUMED":
+ pass
+
+ # misc
+ # (every other event type goes to the registered handler functions)
+ case _:
+ self.handle_dispatch(data=data, payload=payload)
+
+ case 1: # Heartbeat
+ self.send({"op":1, "d":self.last_s}) # send heartbeat back
+
+ case 10: # Hello
+ # start heartbeat loop
+ self.start_heartbeat(data["d"]["heartbeat_interval"])
+
+ # `resuming` was copied from the previous client's can_resume in __init__:
+ # send op 6 with the stored session id and sequence number, or op 2 otherwise
+ if self.resuming:
+ self.send({
+ "op":6,
+ "d":{
+ "token":TOKEN,
+ "session_id":self.session_id,
+ "seq":self.last_s
+ }
+ })
+ else:
+ # identify
+ self.send({
+ "op":2,
+ "d":{
+ "token":TOKEN,
+ "intents": intents.GUILD_MESSAGES,
+ "properties": {
+ "os":"linux",
+ "browser":"Discord Image Bridge "+_values.HOMEPAGE,
+ "device":"Just another SoC"
+ },
+ "compress":False,
+ "afk":False,
+ }
+ })
+
+ case 11: # Heartbeat ACK
+ # timestamp the ack; heartbeat() compares it with the time of its last send
+ self.last_heartbeat_ack=monotonic()
+
+ except Exception:
+ import traceback
+ tb=traceback.format_exc()
+ log("Error processing message:\n"+tb)
+
+ def handle_dispatch(self, payload, data):
+     """Call every handler registered for the event type ``payload["t"]``.
+
+     Handlers are called with the keyword arguments ``data`` and ``payload``.
+     A handler that raises is logged together with its traceback and the
+     remaining handlers still run.
+     """
+     handlers=self.registered_functions.get(payload["t"], [])
+     for handler in handlers:
+         try:
+             handler(data=data, payload=payload)
+         except Exception:
+             import traceback
+             trace=traceback.format_exc()
+             log("Error in dispatch handler",handler.__name__,"for event",payload["t"],
+                 ":\n"+trace)
+
+
+ def event(self, events:str|list=None):
+     """Return a decorator that registers a function as a dispatch handler.
+
+     events: one event name, a list of event names, or None. With None (or
+         an empty list) the event name is derived from the function name:
+         upper-cased, with a leading "ON_" removed.
+
+     Event names are stored upper-cased. The decorator returns the function
+     unchanged, so the decorated name stays usable, e.g. for remove_handler().
+     """
+     def decorator(func:callable):
+         # `events` belongs to the enclosing call, so work on a local name
+         # instead of rebinding it
+         if isinstance(events, str):
+             event_list=[events]
+         else:
+             event_list=events
+
+         if not event_list:
+             event_list=[func.__name__.upper().removeprefix("ON_")]
+
+         for event in event_list:
+             event=event.upper()
+             log("Registring function",func.__name__,"under event",event)
+             self.registered_functions.setdefault(event, []).append(func)
+
+         # without this the decorated name would be rebound to None
+         return func
+     return decorator
+
+ def find_handler(self, func):
+     """Yield the name of every event that *func* is registered under."""
+     for name, handlers in self.registered_functions.items():
+         if func not in handlers:
+             continue
+         yield name
+
+ def remove_handler(self, func, event_list:str|list=None):
+     """Unregister *func* from the given events.
+
+     event_list: one event name, a list of names, or None/empty to use
+         every event that find_handler() reports for *func*.
+
+     Names are upper-cased before the lookup. An unknown event name is
+     looked up as an empty list, so removing from it raises ValueError just
+     like removing a function that is not registered under a known event.
+     """
+     targets=[event_list] if type(event_list)==str else event_list
+     if not targets:
+         targets=self.find_handler(func)
+
+     for name in targets:
+         handlers=self.registered_functions.get(name.upper(), [])
+         handlers.remove(func)
+
+
+
+class DiscordWsManager(cherrypy.process.plugins.SimplePlugin):
+    """CherryPy plugin that keeps a Discord gateway connection running.
+
+    start() launches a thread that connects to the gateway and reconnects
+    whenever the socket closes; stop() ends that loop and waits for the thread.
+    """
+    def __init__(self, bus, dbpool:DBPoolManager):
+        cherrypy.process.plugins.SimplePlugin.__init__(self, bus)
+        # placeholder without a socket until the first connection is made
+        self.client:DiscordWsClient=DiscordWsClientMandatoryAttrs()
+        # set by stop() to end the connection loop
+        self.closing=Event()
+
+        self.dbpool=dbpool
+
+    def start(self):
+        """Subscribe to the ready event and start the connection loop thread."""
+        # subscribe first, so the ready event cannot be published before the
+        # handler is listening
+        cherrypy.engine.subscribe("discord-gateway-ready",self.register_event_handlers)
+        self.manager_thread=Thread(target=self.manager)
+        self.manager_thread.start()
+
+    def manager(self):
+        """Connection loop: connect, wait for the socket to close, repeat until stopped."""
+        log("Starting discord gateway connection loop")
+        while not self.closing.is_set():
+            #log(self.client.can_resume)
+            if self.client and self.client.can_resume:
+                log("Reconnecting to gateway")
+                url=self.client.resume_url
+            else:
+                log("Requesting new gateway URL")
+                rq=Request("GET","/gateway/bot")
+                resp=rl.sendrequest(rq, 'gateway')
+                data=resp.json()
+
+                url=data["url"]
+
+                limit=data["session_start_limit"]
+                log("Reconnects left:",limit["remaining"],"; resets after:",limit["reset_after"])
+                if limit["remaining"]<=0:
+                    # reset_after is the number of milliseconds until the limit
+                    # resets, not a timestamp
+                    wait_seconds=limit["reset_after"]/1000
+                    log("No session starts left; waiting",wait_seconds,"seconds for the limit to reset")
+                    # wait on the closing event rather than sleeping, so that
+                    # stop() does not have to sit out the whole delay
+                    if self.closing.wait(wait_seconds):
+                        break
+
+            log("Using gateway url:",url)
+
+            self.client=DiscordWsClient(self.client, url)
+            # keep the url around for a reconnect; a READY event replaces it
+            self.client.resume_url=url
+            self.client.connect()
+            self.client.closed_event.wait() # wait for socket to exit
+            log("Gateway disconnected.")
+
+    def stop(self):
+        """End the connection loop, close the current client and wait for the thread."""
+        self.closing.set()
+        if self.client:
+            self.client.close()
+        log("Waiting for gateway socket to close (may take a bit)...")
+        self.manager_thread.join()
+        cherrypy.engine.unsubscribe("discord-gateway-ready",self.register_event_handlers)
+
+    def register_event_handlers(self, client:DiscordWsClient, data, payload):
+        """Handle the first "discord-gateway-ready" event.
+
+        Unsubscribes itself so it runs only once, then hands over to
+        ``utils.on_ready(self, client)``.
+        """
+        from .utils import on_ready
+        cherrypy.engine.unsubscribe("discord-gateway-ready",self.register_event_handlers)
+        on_ready(self, client)
+
+def _response2return(repl):
+    """Turn a response into ``(status_code, body)``.
+
+    The body is the parsed JSON, or None when the response text is empty.
+    """
+    status=repl.status_code
+    body=None
+    if repl.text:
+        body=repl.json()
+    return (status, body)
+
+def interaction_response_send_message(ctx, content, ephemeral=True, **kwargs):
+    """Answer an interaction with a message.
+
+    ctx: the interaction; its "id" and "token" entries build the callback URL.
+    content: message text.
+    ephemeral: when true, the EPHEMERAL flag bit is added to the message flags.
+    kwargs: further message fields; they are merged after *content*.
+
+    Returns ``(status_code, parsed JSON body or None)``.
+    """
+    message={"content":content}
+    message.update(kwargs)
+
+    if ephemeral:
+        # OR the bit in, so flags passed by the caller are kept instead of
+        # being replaced
+        message["flags"]=(message.get("flags") or 0) | (1<<6) # EPHEMERAL
+
+    params={
+        "type":4, # CHANNEL_MESSAGE_WITH_SOURCE
+        "data":message,
+    }
+
+    rq=Request("POST",
+        "/interactions/%s/%s/callback"%(
+            ctx["id"],
+            ctx["token"]),
+        json=params
+    )
+
+    repl=rl.sendrequest(rq, "interactions")
+    return _response2return(repl)
+
+def channel_thread_create(channel_id, name, **params):
+    """POST a new thread called *name* to the channel's threads route.
+
+    Extra keyword arguments are sent along in the JSON body.
+    Returns ``(status_code, parsed JSON body or None)``.
+    """
+    channel=str(channel_id)
+    body=params
+    body["name"]=name
+
+    rq=Request("POST", "/channels/"+channel+"/threads", json=body)
+    repl=rl.sendrequest(rq, "channels/"+channel)
+    return _response2return(repl)
+
+def channel_message_send(channel_id, content=None, **params):
+    """POST a message to the channel's messages route.
+
+    *content* is always included in the JSON body (as null when it is None),
+    together with any extra keyword arguments.
+    Returns ``(status_code, parsed JSON body or None)``.
+    """
+    channel=str(channel_id)
+    body=params
+    body["content"]=content
+
+    rq=Request("POST", "/channels/"+channel+"/messages", json=body)
+    repl=rl.sendrequest(rq, "channels/"+channel)
+    return _response2return(repl)
+
+
diff --git a/src/discord_image_bridge/sendqueue.py b/src/discord_image_bridge/sendqueue.py
new file mode 100644
index 0000000..ca18f8b
--- /dev/null
+++ b/src/discord_image_bridge/sendqueue.py
@@ -0,0 +1,124 @@
+"""
+Implements handling ratelimiting on requests. Also adds required headers (auth
+etc) to the request.
+Does not actually formulate any requests, only sends them and makes sure to
+avoid hitting ratelimits
+
+Note the contents of this module are to be used for internal purposes, you
+should never have to interact with the classes and functions in here directly.
+
+Implementation details:
+ we avoid hitting ratelimits by keeping counters based on the X-RateLimit-*
+ headers and avoiding hitting them.
+
+ we avoid getting cloudflare bans by keeping a Counter of errors per minute,
+ and then checking the total amount of errors in the last 10 minutes before
+ sending a request, waiting while it exceeds 9000 (way below the 10000 limit).
+"""
+
+from time import sleep
+from datetime import datetime, timezone
+from collections import Counter
+
+from requests import Request, Session, Response
+
+from . import _values
+
+# Only the handler class is exported by ``from .sendqueue import *``; time()
+# and the sendrequest instance remain reachable as module attributes.
+__all__=["RateLimitHandler"]
+
+def time():
+    """Private function. Get time in UTC, no matter our own timezone, using datetime.
+
+    Returns the current time as a float number of seconds since the epoch.
+    """
+    # timestamp() has to be called on the datetime instance, not on the class
+    return datetime.now(tz=timezone.utc).timestamp()
+
+class RateLimitHandler():
+    """From the module description:
+        Note the contents of this module are to be used for internal purposes, you
+        should never have to interact with [this class] directly.
+    See module description for details.
+
+    Calling an instance sends a request; see __call__.
+    """
+    def __init__(self):
+        self.global_left=1
+        self.global_retry_after=0
+
+        self.retry_after:dict[str,float] = dict() # bucket -> reset time (epoch seconds)
+        self.bucket_left:dict[str,float] = dict() # bucket -> requests left
+        self.bucket_keys:dict[str,str] = dict() # method+routekey -> bucket
+        self.session=Session()
+
+        # for cloudflare ban avoidance.
+        self.invalid_requests_per_minute = Counter() # for tracking the hard 10000 err/minute limit
+
+    def __call__(self,request:Request, routekey, _retries=0):
+        """Sends a request to discord, taking ratelimiting into account.
+
+        request: the request to send; a url that does not start with the API
+            base url is treated as a route and appended to it. Its headers
+            are replaced by the required authentication/content headers.
+        routekey: name of the route, used (with the HTTP method) to remember
+            which ratelimit bucket the route belongs to.
+        _retries: internal counter of retries caused by 429 responses.
+
+        Returns the first response that is not a 429.
+        """
+        self.avoid_cloudflare_ban() # avoid the 1h cloudflare ban on too many errors
+
+        if not request.url.startswith(_values.BASEURL):
+            # join base url and route with exactly one slash between them
+            request.url=_values.BASEURL.rstrip("/")+"/"+request.url.lstrip("/")
+
+        request.headers={ # add required headers for authentication
+            "Authorization":"Bot "+_values.TOKEN,
+            "User-Agent":_values.USERAGENT,
+            "Content-Type":"application/json",
+            "Accept":"application/json"
+        }
+
+        sendable=self.session.prepare_request(request)
+
+        if self.global_left <= 0: # avoid passing global ratelimit
+            self._sleep_until(self.global_retry_after)
+
+        routeid=request.method+routekey
+        bucketkey=self.bucket_keys.get(routeid)
+        # avoid passing bucket ratelimit
+        if bucketkey and self.bucket_left.get(bucketkey, 1) <= 0:
+            self._sleep_until(self.retry_after.get(bucketkey, 0))
+
+        # send request
+        resp = self.session.send(sendable)
+
+        # keep the error counters that avoid_cloudflare_ban() relies on current
+        self.update_invalid_per_minute(resp)
+
+        # update ratelimit values
+        if "X-RateLimit-Global" in resp.headers:
+            self.global_left = int(resp.headers.get("X-RateLimit-Remaining", 0))
+            self.global_retry_after = float(resp.headers.get("X-RateLimit-Reset", 0))
+
+        elif not resp.headers.get("X-RateLimit-Scope")=="shared":
+            bucketkey=resp.headers.get("X-RateLimit-Bucket", bucketkey)
+            if bucketkey:
+                # remember the bucket of this route, so that the check above
+                # can find it on the next request
+                self.bucket_keys[routeid]=bucketkey
+                self.bucket_left[bucketkey] = int(resp.headers.get("X-RateLimit-Remaining",0))
+                self.retry_after[bucketkey] = float(resp.headers.get("X-RateLimit-Reset", 0))
+
+        if resp.status_code == 429: # if we hit a ratelimit anyways, recall self to retry later
+            # wait as long as the response asks for, instead of retrying at once
+            sleep(self._retry_delay(resp))
+            return self.__call__(request=request, routekey=routekey, _retries=_retries+1)
+
+        return resp
+
+    @staticmethod
+    def _now():
+        "private method. returns the current UTC time in seconds since epoch."
+        return datetime.now(tz=timezone.utc).timestamp()
+
+    def _sleep_until(self, timestamp):
+        "private method. sleeps until the given time (seconds since epoch), if it is still ahead."
+        delay=timestamp-self._now()
+        if delay>0: # never hand a negative value to sleep()
+            sleep(delay)
+
+    def _retry_delay(self, resp:Response):
+        "private method. returns the number of seconds to wait before retrying a 429 response."
+        for header in ("Retry-After", "X-RateLimit-Reset-After"):
+            try:
+                return max(float(resp.headers[header]), 0.0)
+            except (KeyError, ValueError): # header missing or not a number
+                continue
+        return 1.0 # no usable header: short fixed pause
+
+    def _get_minutes_since_epoch(self):
+        "private method. returns minutes since epoch."
+        return int(self._now()//60) # get minutes since epoch
+
+    def update_invalid_per_minute(self, resp:Response):
+        "update error counts for cloudflare ban avoidance"
+        if resp.status_code not in (401, 403, 429):
+            return
+        if resp.status_code==429 and resp.headers.get("X-RateLimit-Scope")=="shared":
+            return # ratelimits of a shared resource do not count as invalid requests
+        self.invalid_requests_per_minute[self._get_minutes_since_epoch()]+=1
+
+    def avoid_cloudflare_ban(self):
+        """
+        makes sure we don't trigger a cloudflare ban by waiting until the amount
+        of http errors in the last 10 minutes stays below 9000
+        """
+        while True:
+            oldest_kept = self._get_minutes_since_epoch()-10
+            # remove minutes that were more than 10 minutes ago; collect the keys
+            # first, because a dict must not change size while it is iterated
+            expired=[key for key in self.invalid_requests_per_minute if key<oldest_kept]
+            for key in expired:
+                del self.invalid_requests_per_minute[key]
+
+            if self.invalid_requests_per_minute.total()<=9000: # stay way below the limit of 10000
+                return
+            sleep(30)
+
+
+# Handler instance created at import time; it is called as
+# ``sendrequest(request, routekey)``.
+sendrequest=RateLimitHandler()
+
+