From cc7c053b7bea0e525bcb4069b0c03a9a6461b187 Mon Sep 17 00:00:00 2001 From: Vosjedev Date: Sat, 24 Jan 2026 12:18:56 +0100 Subject: import discord stuff from other projects and tailor it for this one Signed-off-by: Vosjedev --- src/discord_image_bridge/_values.py | 43 +++++ src/discord_image_bridge/discord.py | 324 ++++++++++++++++++++++++++++++++++ src/discord_image_bridge/sendqueue.py | 124 +++++++++++++ 3 files changed, 491 insertions(+) create mode 100644 src/discord_image_bridge/_values.py create mode 100644 src/discord_image_bridge/discord.py create mode 100644 src/discord_image_bridge/sendqueue.py (limited to 'src') diff --git a/src/discord_image_bridge/_values.py b/src/discord_image_bridge/_values.py new file mode 100644 index 0000000..2b2ed09 --- /dev/null +++ b/src/discord_image_bridge/_values.py @@ -0,0 +1,43 @@ + +import os + +__all__ = [ + "API_VERSION", + "BASEURL", + "TOKEN", + "HOMEPAGE", + "VERSION", + "USERAGENT", + ] + + +API_VERSION=10 +BASEURL="https://discord.com/api/v%d/"%API_VERSION + +TOKEN=os.getenv("TOKEN") + +from importlib.metadata import distribution as _distribution + +_dist=_distribution(__package__) + +# try to determine package homepage +_urls={} +for _value in _dist.metadata.get_all("Project-URL"): + _name,_url=_value.split(", ") + _urls[_name]=_url + +for _possible in ["Homepage","Repository","Documentation"]: + if _possible in _urls: + HOMEPAGE=_urls[_possible] + break +else: + HOMEPAGE=_urls.values()[0] + +del _possible, _urls, _value, _name, _url, _distribution # clean up + +VERSION=_dist.version + +del _dist + +USERAGENT="DiscordBot (%s, %s)"%(HOMEPAGE, VERSION) + diff --git a/src/discord_image_bridge/discord.py b/src/discord_image_bridge/discord.py new file mode 100644 index 0000000..1257bc4 --- /dev/null +++ b/src/discord_image_bridge/discord.py @@ -0,0 +1,324 @@ +""" +Helper functions and a websocket plugin for interacting with discord +""" + +from threading import Event, Thread +from time import monotonic, sleep + +from json import loads, dumps, JSONDecodeError + +from ws4py.client.threadedclient import WebSocketClient +from ws4py.messaging import TextMessage + +import cherrypy + +from requests import Request + +from . import sendqueue as rl +from .dbpool import DBPoolManager + +from . import _values +TOKEN=_values.TOKEN + +def log(*msg): + cherrypy.log(" ".join([ i if type(i)==str else str(i) for i in msg ]),context="DISCORD") + +class intents: # intents + GUILD_MESSAGES = 1<<9 + +class DiscordWsClientMandatoryAttrs(): + def __init__(self, previous_client=None): + self.closed_event=Event() + + self.can_resume:bool=True + self.session_id:str=None + self.last_s:int=0 + + self.registered_functions={} + + if not previous_client: + self.closed_event.set() + self.can_resume=False + + def close(): + pass + +class DiscordWsClient(WebSocketClient, DiscordWsClientMandatoryAttrs): + def __init__(self, previous_client, *args, **kwargs): + WebSocketClient.__init__(self, *args, **kwargs) + DiscordWsClientMandatoryAttrs.__init__(self, previous_client) + + self.last_s:int=previous_client.last_s + self.resuming:bool=previous_client.can_resume + self.resume_url:str=None + self.session_id:str=previous_client.session_id + self.last_heartbeat_ack=0 + + self.registered_functions:dict[str, list[callable]]=previous_client.registered_functions + + def send(self, payload): + super().send(dumps(payload)) + + def opened(self): + # TODO + log("Gateway opened") + pass + + def closed(self, code, reason=None): + if code in (4000, 4001, 4002, 4003, 4005, 4006, 4007, 4008, 4009): + self.can_resume=True + else: + self.can_resume=False + log("Gateway closed. resumable:",self.can_resume) + self.closed_event.set() + pass + + def heartbeat(self, interval:int): + last_heartbeat=0 + skipped_beats=0 + while not self.closed_event.wait(interval/1000): # milliseconds + # check if connection is still alive + if self.last_heartbeat_ack=2: + log("2 heartbeats skipped, so probably dead. Closing connection and resuming.") + self.can_resume=True + self.close(4000) # any code except 1000/1001 is fine + else: + skipped_beats=0 + + # send heartbeat + #log("Heartbeat") + self.send({"op":1, "d":self.last_s}) + last_heartbeat=monotonic() + + def start_heartbeat(self, interval:int): + th=Thread(target=self.heartbeat, args=(interval,)) + th.start() + + def received_message(self, m:TextMessage): + try: + data=loads(m.data.decode(m.encoding)) + except (UnicodeDecodeError, JSONDecodeError) as e: + log("Error loading message from gateway: ",e,"\nExact message:",repr(m.data)) + log("Ignoring event.") + return + + try: # one big try-catch block... not amazing + match data["op"]: # match opcode + case 0: # Dispatch + log("Dispatch (op 0):",data["t"]) + self.last_s=data["s"] + # switch some vars around for ease of usage + payload=data + data=payload["d"] + + # match/case the event type and respond accordingly + match payload["t"]: + # connection-related events + case "READY": + self.resume_url=data["resume_gateway_url"] + self.session_id=data["session_id"] + cherrypy.engine.publish("discord-gateway-ready", client=self, data=data, payload=payload) + + case "RESUMED": + pass + + # misc + case _: + self.handle_dispatch(data=data, payload=payload) + + case 1: # Heartbeat + self.send({"op":1, "d":self.last_s}) # send heartbeat back + + case 10: # Hello + # start heartbeat loop + self.start_heartbeat(data["d"]["heartbeat_interval"]) + + if self.resuming: + self.send({ + "op":6, + "d":{ + "token":TOKEN, + "session_id":self.session_id, + "seq":self.last_s + } + }) + else: + # identify + self.send({ + "op":2, + "d":{ + "token":TOKEN, + "intents": intents.GUILD_MESSAGES, + "properties": { + "os":"linux", + "browser":"Discord Image Bridge "+_values.HOMEPAGE, + "device":"Just another SoC" + }, + "compress":False, + "afk":False, + } + }) + + case 11: # Heartbeat ACK + self.last_heartbeat_ack=monotonic() + + except Exception: + import traceback + tb=traceback.format_exc() + log("Error processing message:\n"+tb) + + def handle_dispatch(self, payload, data): + for function in self.registered_functions.get(payload["t"], []): + try: + function(data=data, payload=payload) + except Exception: + import traceback + log("Error in dispatch handler",function.__name__,"for event",payload["t"], + ":\n"+traceback.format_exc()) + + + def event(self, events:str|list=None): + # basically, to use a method as decorator, the method returns a decorator function. + # due to python scoping being weird, we also have to redefine events under a different + # name to avoid scope issues + def decorator(func:callable): + if type(events)==str: + event_list=[events] + else: + event_list=events + + if not event_list: + event_list=[func.__name__.upper().removeprefix("ON_")] + + for event in event_list: + event=event.upper() + log("Registring function",func.__name__,"under event",event) + self.registered_functions.setdefault(event, []).append(func) + return decorator + + def find_handler(self, func): + for event, funcs in self.registered_functions.items(): + if func in funcs: + yield event + + def remove_handler(self, func, event_list:str|list=None): + if type(event_list)==str: + event_list=[event_list] + if not event_list: + event_list=self.find_handler(func) + + for event in event_list: + event=event.upper() + self.registered_functions.get(event, []).remove(func) + + + +class DiscordWsManager(cherrypy.process.plugins.SimplePlugin): + def __init__(self, bus, dbpool:DBPoolManager): + cherrypy.process.plugins.SimplePlugin.__init__(self, bus) + self.client:DiscordWsClient=DiscordWsClientMandatoryAttrs() + self.closing=Event() + + self.dbpool=dbpool + + def start(self): + self.manager_thread=Thread(target=self.manager) + self.manager_thread.start() + cherrypy.engine.subscribe("discord-gateway-ready",self.register_event_handlers) + + def manager(self): + log("Starting discord gateway connection loop") + while not self.closing.is_set(): + #log(self.client.can_resume) + if self.client and self.client.can_resume: + log("Reconnecting to gateway") + url=self.client.resume_url + else: + log("Requesting new gateway URL") + rq=Request("GET","/gateway/bot") + resp=rl.sendrequest(rq, 'gateway') + data=resp.json() + + url=data["url"] + + data=data["session_start_limit"] + log("Reconnects left:",data["remaining"],"; resets after:",data["reset_after"]) + if data["remaining"]<0: + time_until_reset=data["reset_after"]-rl.time() + log("Waiting until",data["reset_after"],";",time_until_reset,"seconds left") + sleep(time_until_reset) + + + log("Using gateway url:",url) + + self.client=DiscordWsClient(self.client, url) + self.client.resume_url=url + self.client.connect() + self.client.closed_event.wait() # wait for socket to exit + log("Gateway disconnected.") + + + def stop(self): + self.closing.set() + if self.client: + self.client.close() + log("Waiting for gateway socket to close (may take a bit)...") + self.manager_thread.join() + cherrypy.engine.unsubscribe("discord-gateway-ready",self.register_event_handlers) + + def register_event_handlers(self, client:DiscordWsClient, data, payload): + from .utils import on_ready + cherrypy.engine.unsubscribe("discord-gateway-ready",self.register_event_handlers) + on_ready(self, client) + +def _response2return(repl): + return (repl.status_code, repl.json() if repl.text else None) + +def interaction_response_send_message(ctx, content, ephemeral=True, **kwargs): + params={ + "type":4, # CHANNEL_MESSAGE_WITH_SOURCE + "data":{ + "content":content, + } + } + params["data"].update(kwargs) + + if ephemeral: + params["data"]["flags"]=1<<6 # EPHEMERAL + + rq=Request("POST", + "/interactions/%s/%s/callback"%( + ctx["id"], + ctx["token"]), + json=params + ) + + repl=rl.sendrequest(rq, "interactions") + return _response2return(repl) + +def channel_thread_create(channel_id, name, **params): + params["name"]=name + + rq=Request("POST", + "/channels/%s/threads"%str(channel_id), + json=params + ) + repl=rl.sendrequest(rq, "channels/%s"%str(channel_id)) + return _response2return(repl) + +def channel_message_send(channel_id, content=None, **params): + params["content"]=content + rq=Request("POST", + "/channels/%s/messages"%str(channel_id), + json=params + ) + repl=rl.sendrequest(rq, "channels/%s"%str(channel_id)) + return _response2return(repl) + + diff --git a/src/discord_image_bridge/sendqueue.py b/src/discord_image_bridge/sendqueue.py new file mode 100644 index 0000000..ca18f8b --- /dev/null +++ b/src/discord_image_bridge/sendqueue.py @@ -0,0 +1,124 @@ +""" +Implements handling ratelimiting on requests. Also adds required headers (auth +etc) to the request. +Does not actually formulate any requests, only sends them and makes sure to +avoid hitting ratelimits + +Note the contents of this module are to be used for internal purposes, you +should never have to interact with the classes and functions in here directly. + +Implementation details: + we avoid hitting ratelimits by keeping counters based on the X-RateLimit-* + headers and avoiding hitting them. + + we avoid getting cloudflare bans by keeping a Counter of errors per minute, + and then checking the total amount of errors in the last 10 minutes before + sending a request, waiting while it exceeds 9000 (way below the 10000 limit). +""" + +from time import sleep +from datetime import datetime, timezone +from collections import Counter + +from requests import Request, Session, Response + +from . import _values + +__all__=["RateLimitHandler"] + +def time(): + "Private function. Get time in UTC, no matter our own timezone, using datetime." + datetime.now(tz=timezone.utc) + return datetime.timestamp() + +class RateLimitHandler(): + """From the module description: + Note the contents of this module are to be used for internal purposes, you + should never have to interact with [this class] directly. + See module description for details. + """ + def __init__(self): + self.global_left=1 + self.global_retry_after=0 + + self.retry_after:dict[str,float] = dict() + self.bucket_left:dict[str,float] = dict() + self.bucket_keys:dict[str,str] = dict() + self.session=Session() + + # for cloudflare ban avoidance. + self.invalid_requests_per_minute = Counter() # for tracking the hard 10000 err/minute limit + + def __call__(self,request:Request, routekey, _retries=0): + "Sends a request to discord, taking ratelimiting into account." + self.avoid_cloudflare_ban() # avoid the 1h cloudflare ban on too many errors + + if not request.url.startswith(_values.BASEURL): + request.url=_values.BASEURL+request.url + + request.headers={ # add required headers for authentication + "Authorization":"Bot "+_values.TOKEN, + "User-Agent":_values.USERAGENT, + "Content-Type":"application/json", + "Accept":"application/json" + } + + sendable=self.session.prepare_request(request) + + if self.global_left <= 0 and time()9000: # stay way below the limit of 10000 + minute = self._get_minutes_since_epoch() + # remove minutes that were more than 10 minutes ago + for key in self.invalid_requests_per_minute.keys(): # should be a short iteration, we should have at max 10-12 items + if key