aboutsummarybugs & patchesrefslogtreecommitdiffstats
"""
Helper functions and a CherryPy websocket plugin for interacting with Discord.
"""

from threading import Event, Thread
from time import monotonic, sleep

from json import loads, dumps, JSONDecodeError

from ws4py.client.threadedclient import WebSocketClient
from ws4py.messaging import TextMessage

import cherrypy

from requests import Request

from . import sendqueue as rl

from . import _values
TOKEN=_values.TOKEN

def log(*msg):
	"""Write the given values, converted with str() and joined by single spaces,
	through cherrypy.log with the "DISCORD" context."""
	text=" ".join(map(str, msg))
	cherrypy.log(text, context="DISCORD")

class intents:
	"""Gateway intent bit flags; OR them together for the identify payload."""
	GUILD_MESSAGES = 0b1000000000  # bit 9

class DiscordWsClientMandatoryAttrs():
	"""State that the gateway manager reads from a client.

	The manager starts out holding an instance of this class created without a
	previous client, so "no connection yet" looks like an already-closed,
	non-resumable connection. DiscordWsClient also inherits from it to get the
	same attributes initialised.
	"""
	def __init__(self, previous_client=None):
		# set once there is no live connection; the manager waits on it
		self.closed_event=Event()

		self.can_resume:bool=True
		self.session_id:str|None=None
		self.resume_url:str|None=None
		self.last_s:int=0

		# event name -> list of handler callables
		self.registered_functions:dict={}

		if not previous_client:
			# nothing is connected and there is no session to resume
			self.closed_event.set()
			self.can_resume=False

	def close(self, *args, **kwargs):
		"""No-op: there is no connection to close.

		Accepts and ignores the same arguments as a real client's close(code, ...)
		so callers can close whichever client they currently hold.
		"""

class DiscordWsClient(WebSocketClient, DiscordWsClientMandatoryAttrs):
	def __init__(self, previous_client, *args, **kwargs):
		WebSocketClient.__init__(self, *args, **kwargs)
		DiscordWsClientMandatoryAttrs.__init__(self, previous_client)

		self.last_s:int=previous_client.last_s
		self.resuming:bool=previous_client.can_resume
		self.resume_url:str=None
		self.session_id:str=previous_client.session_id
		self.last_heartbeat_ack=0

		self.registered_functions:dict[str, list[callable]]=previous_client.registered_functions

	def send(self, payload):
		super().send(dumps(payload))
	
	def opened(self):
		# TODO
		log("Gateway opened")
		pass
	
	def closed(self, code, reason=None):
		if code in (4000, 4001, 4002, 4003, 4005, 4006, 4007, 4008, 4009):
			self.can_resume=True
		else:
			self.can_resume=False
		log("Gateway closed. resumable:",self.can_resume)
		self.closed_event.set()
		pass

	def heartbeat(self, interval:int):
		last_heartbeat=0
		skipped_beats=0
		while not self.closed_event.wait(interval/1000): # milliseconds
			# check if connection is still alive
			if self.last_heartbeat_ack<last_heartbeat: # if no heartbeat was ack'ed after last one
				log(
					"Last gateway heartbeat ack was before last heartbeat. Last ack: %d, last hb: %d"%(
						self.last_heartbeat_ack, last_heartbeat)
				)
				skipped_beats+=1 # we track the amount of beats skipped, to account for unfortunate timing in the scheduler
				if skipped_beats>=2:
					log("2 heartbeats skipped, so probably dead. Closing connection and resuming.")
					self.can_resume=True
					self.close(4000) # any code except 1000/1001 is fine
			else:
				skipped_beats=0

			# send heartbeat
			#log("Heartbeat")
			self.send({"op":1, "d":self.last_s})
			last_heartbeat=monotonic()

	def start_heartbeat(self, interval:int):
		th=Thread(target=self.heartbeat, args=(interval,))
		th.start()

	def received_message(self, m:TextMessage):
		try:
			data=loads(m.data.decode(m.encoding))
		except (UnicodeDecodeError, JSONDecodeError) as e:
			log("Error loading message from gateway: ",e,"\nExact message:",repr(m.data))
			log("Ignoring event.")
			return
		
		try: # one big try-catch block... not amazing
			match data["op"]: # match opcode
				case 0: # Dispatch
					log("Dispatch (op 0):",data["t"])
					self.last_s=data["s"]
					# switch some vars around for ease of usage
					payload=data
					data=payload["d"]

					# match/case the event type and respond accordingly
					match payload["t"]:
						# connection-related events
						case "READY":
							self.resume_url=data["resume_gateway_url"]
							self.session_id=data["session_id"]
							cherrypy.engine.publish("discord-gateway-ready", client=self, data=data, payload=payload)

						case "RESUMED":
							pass

						# misc
						case _:
							self.handle_dispatch(data=data, payload=payload)

				case 1: # Heartbeat
					self.send({"op":1, "d":self.last_s}) # send heartbeat back

				case 10: # Hello
					# start heartbeat loop
					self.start_heartbeat(data["d"]["heartbeat_interval"])

					if self.resuming:
						self.send({
							"op":6,
							"d":{
								"token":TOKEN,
								"session_id":self.session_id,
								"seq":self.last_s
								}
							})
					else:
						# identify
						self.send({
							"op":2,
							"d":{
								"token":TOKEN,
								"intents": intents.GUILD_MESSAGES,
								"properties": {
									"os":"linux",
									"browser":"Discord Image Bridge "+_values.HOMEPAGE,
									"device":"Just another SoC"
									},
								"compress":False,
								"afk":False,
								}
							})

				case 11: # Heartbeat ACK
					self.last_heartbeat_ack=monotonic()

		except Exception:
			import traceback
			tb=traceback.format_exc()
			log("Error processing message:\n"+tb)
	
	def handle_dispatch(self, payload, data):
		for function in self.registered_functions.get(payload["t"], []):
			try:
				function(data=data, payload=payload)
			except Exception:
				import traceback
				log("Error in dispatch handler",function.__name__,"for event",payload["t"],
					":\n"+traceback.format_exc())


	def event(self, events:str|list=None):
		# basically, to use a method as decorator, the method returns a decorator function.
		# due to python scoping being weird, we also have to redefine events under a different
		# name to avoid scope issues
		def decorator(func:callable):
			if type(events)==str:
				event_list=[events]
			else:
				event_list=events

			if not event_list:
				event_list=[func.__name__.upper().removeprefix("ON_")]

			for event in event_list:
				event=event.upper()
				#log("Registring function",func.__name__,"under event",event)
				self.registered_functions.setdefault(event, []).append(func)
		return decorator
	
	def find_handler(self, func):
		for event, funcs in self.registered_functions.items():
			if func in funcs:
				yield event
	
	def remove_handler(self, func, event_list:str|list=None):
		if type(event_list)==str:
			event_list=[event_list]
		if not event_list:
			event_list=self.find_handler(func)

		for event in event_list:
			event=event.upper()
			self.registered_functions.get(event, []).remove(func)
	


class DiscordWsManager(cherrypy.process.plugins.SimplePlugin):
	"""CherryPy plugin that keeps a Discord gateway connection alive.

	start() spawns a thread that reconnects whenever the current client's
	connection closes. It resumes the session when the closed client allows it
	and otherwise fetches a fresh gateway URL from ``/gateway/bot``.
	"""

	# seconds to wait before retrying after a failed REST call or connect
	RETRY_DELAY=5

	def __init__(self, bus):
		cherrypy.process.plugins.SimplePlugin.__init__(self, bus)
		# placeholder: looks like a closed, non-resumable connection
		self.client:DiscordWsClient|DiscordWsClientMandatoryAttrs=DiscordWsClientMandatoryAttrs()
		self.closing=Event()
		self.manager_thread:Thread|None=None

	def start(self):
		"""Start the connection loop thread (also after a previous stop())."""
		self.closing.clear()
		# subscribe before connecting so the READY notification cannot be missed
		cherrypy.engine.subscribe("discord-gateway-ready",self.register_event_handlers)
		self.manager_thread=Thread(target=self.manager)
		self.manager_thread.start()

	def _fetch_gateway_url(self):
		"""Get a gateway URL over REST, waiting out the session start limit if it is used up.

		Returns None if stop() was called while waiting for the limit to reset.
		"""
		log("Requesting new gateway URL")
		rq=Request("GET","/gateway/bot")
		resp=rl.sendrequest(rq, 'gateway')
		data=resp.json()

		url=data["url"]

		limit=data["session_start_limit"]
		log("Reconnects left:",limit["remaining"],"; resets after:",limit["reset_after"])
		if limit["remaining"]<=0:
			# per the Discord API docs, reset_after is a duration in milliseconds
			time_until_reset=limit["reset_after"]/1000
			log("Waiting until",limit["reset_after"],";",time_until_reset,"seconds left")
			# interruptible sleep: returns True as soon as stop() sets closing
			if self.closing.wait(time_until_reset):
				return None
		return url

	def manager(self):
		"""Connection loop: connect, wait for the socket to close, repeat until stop()."""
		log("Starting discord gateway connection loop")
		while not self.closing.is_set():
			previous=self.client
			# a session can only be resumed once READY has given us a session id
			if previous.can_resume and previous.session_id is not None:
				log("Reconnecting to gateway")
				url=previous.resume_url
			else:
				# the next client reads this to decide between IDENTIFY and RESUME
				previous.can_resume=False
				try:
					url=self._fetch_gateway_url()
				except Exception as e:
					log("Failed to get gateway URL:",e)
					self.closing.wait(self.RETRY_DELAY)
					continue
				if url is None: # stopped while waiting for the session limit
					break

			log("Using gateway url:",url)

			client=DiscordWsClient(previous, url)
			client.resume_url=url
			try:
				client.connect()
			except Exception as e:
				# keep the previous client so its resume state is reused on retry
				log("Failed to connect to gateway:",e)
				self.closing.wait(self.RETRY_DELAY)
				continue
			self.client=client
			if self.closing.is_set():
				# stop() may have run before self.client pointed at this client
				client.close()
			client.closed_event.wait() # wait for socket to exit
			log("Gateway disconnected.")

	def stop(self):
		"""Stop the connection loop, close the current connection and wait for the thread."""
		self.closing.set()
		# the placeholder has no connection to close
		if isinstance(self.client, DiscordWsClient):
			self.client.close()
		log("Waiting for gateway socket to close (may take a bit)...")
		if self.manager_thread is not None:
			self.manager_thread.join()
		cherrypy.engine.unsubscribe("discord-gateway-ready",self.register_event_handlers)

	def register_event_handlers(self, client:DiscordWsClient, data, payload):
		"""One-shot listener for the first READY: unsubscribes itself, then calls utils.on_ready."""
		from .utils import on_ready
		cherrypy.engine.unsubscribe("discord-gateway-ready",self.register_event_handlers)
		on_ready(self, client)

def _response2return(repl):
	return (repl.status_code, repl.json() if repl.text else None)


def channel_message_send(channel_id, content=None, **params):
	"""POST a message to a channel.

	Extra keyword arguments are sent as JSON fields alongside ``content``.
	Returns ``(status_code, json_body_or_None)``.
	"""
	channel=str(channel_id)
	body={**params, "content": content}
	request=Request("POST", "/channels/%s/messages"%channel, json=body)
	# queued under a per-channel key in the sendqueue module
	response=rl.sendrequest(request, "channels/%s"%channel)
	return _response2return(response)

def channel_message_get(channel_id, message_id):
	"""GET a single message from a channel.

	Returns ``(status_code, json_body_or_None)``.
	"""
	channel=str(channel_id)
	path="/channels/%s/messages/%s"%(channel, str(message_id))
	request=Request("GET", path)
	# queued under a per-channel key in the sendqueue module
	return _response2return(rl.sendrequest(request, "channels/%s"%channel))