aboutsummarybugs & patchesrefslogtreecommitdiffstats
path: root/src/discord_image_bridge/discord.py
blob: 2bfaca55226dd7ab76ed224d7b5c683c3b497ff2 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
"""
Helper functions and a websocket plugin for interacting with Discord.
"""

from threading import Event, Thread
from time import monotonic, sleep

from json import loads, dumps, JSONDecodeError

from ws4py.client.threadedclient import WebSocketClient
from ws4py.messaging import TextMessage

import cherrypy

from requests import Request

from . import sendqueue as rl

from . import _values
# Module-level alias for the token from _values; it is sent as the "token"
# field of the identify (op 2) and resume (op 6) gateway payloads below.
TOKEN=_values.TOKEN

def log(*msg):
	"""
	Write one line to the cherrypy log under the "DISCORD" context.

	Every positional argument is converted with str() and the results are
	joined with single spaces.
	"""
	text=" ".join(map(str, msg))
	cherrypy.log(text, context="DISCORD")

class intents:
	"""Bit flags used for the "intents" field of the identify payload."""
	GUILD_MESSAGES = 0x200 # bit 9, i.e. 1<<9

class DiscordWsClientMandatoryAttrs():
	"""
	Holds the attributes a gateway client is expected to expose.

	It doubles as a "no connection yet" placeholder: when constructed without
	a previous client it reports itself as already closed and not resumable.

	previous_client: the client this one replaces, or None/falsy for the
	very first (placeholder) instance.
	"""
	def __init__(self, previous_client=None):
		# set once the connection is closed
		self.closed_event=Event()

		self.can_resume:bool=True
		self.session_id:str=None
		self.last_s:int=0
		# kept here too so a placeholder exposes the same attribute a real client has
		self.resume_url:str=None

		# event name -> list of handler functions
		self.registered_functions={}

		if not previous_client:
			# nothing to resume from: behave like an already closed connection
			self.closed_event.set()
			self.can_resume=False

	def close(self, code=1000, reason=""):
		"""
		No-op; there is no socket to close on a placeholder.

		Takes self (the original definition did not, so calling close() on an
		instance raised TypeError) and accepts the optional code/reason
		arguments so it can be called the same way as a real client's close().
		"""
		pass

class DiscordWsClient(WebSocketClient, DiscordWsClientMandatoryAttrs):
	def __init__(self, previous_client, *args, **kwargs):
		WebSocketClient.__init__(self, *args, **kwargs)
		DiscordWsClientMandatoryAttrs.__init__(self, previous_client)

		self.last_s:int=previous_client.last_s
		self.resuming:bool=previous_client.can_resume
		self.resume_url:str=None
		self.session_id:str=previous_client.session_id
		self.last_heartbeat_ack=0

		self.registered_functions:dict[str, list[callable]]=previous_client.registered_functions

	def send(self, payload):
		super().send(dumps(payload))
	
	def opened(self):
		# TODO
		log("Gateway opened")
		pass
	
	def closed(self, code, reason=None):
		if code in (4000, 4001, 4002, 4003, 4005, 4006, 4007, 4008, 4009):
			self.can_resume=True
		else:
			self.can_resume=False
		log("Gateway closed. resumable:",self.can_resume)
		self.closed_event.set()
		pass

	def heartbeat(self, interval:int):
		last_heartbeat=0
		skipped_beats=0
		while not self.closed_event.wait(interval/1000): # milliseconds
			# check if connection is still alive
			if self.last_heartbeat_ack<last_heartbeat: # if no heartbeat was ack'ed after last one
				log(
					"Last gateway heartbeat ack was before last heartbeat. Last ack: %d, last hb: %d"%(
						self.last_heartbeat_ack, last_heartbeat)
				)
				skipped_beats+=1 # we track the amount of beats skipped, to account for unfortunate timing in the scheduler
				if skipped_beats>=2:
					log("2 heartbeats skipped, so probably dead. Closing connection and resuming.")
					self.can_resume=True
					self.close(4000) # any code except 1000/1001 is fine
			else:
				skipped_beats=0

			# send heartbeat
			#log("Heartbeat")
			self.send({"op":1, "d":self.last_s})
			last_heartbeat=monotonic()

	def start_heartbeat(self, interval:int):
		th=Thread(target=self.heartbeat, args=(interval,))
		th.start()

	def received_message(self, m:TextMessage):
		try:
			data=loads(m.data.decode(m.encoding))
		except (UnicodeDecodeError, JSONDecodeError) as e:
			log("Error loading message from gateway: ",e,"\nExact message:",repr(m.data))
			log("Ignoring event.")
			return
		
		try: # one big try-catch block... not amazing
			match data["op"]: # match opcode
				case 0: # Dispatch
					log("Dispatch (op 0):",data["t"])
					self.last_s=data["s"]
					# switch some vars around for ease of usage
					payload=data
					data=payload["d"]

					# match/case the event type and respond accordingly
					match payload["t"]:
						# connection-related events
						case "READY":
							self.resume_url=data["resume_gateway_url"]
							self.session_id=data["session_id"]
							cherrypy.engine.publish("discord-gateway-ready", client=self, data=data, payload=payload)

						case "RESUMED":
							pass

						# misc
						case _:
							self.handle_dispatch(data=data, payload=payload)

				case 1: # Heartbeat
					self.send({"op":1, "d":self.last_s}) # send heartbeat back

				case 10: # Hello
					# start heartbeat loop
					self.start_heartbeat(data["d"]["heartbeat_interval"])

					if self.resuming:
						self.send({
							"op":6,
							"d":{
								"token":TOKEN,
								"session_id":self.session_id,
								"seq":self.last_s
								}
							})
					else:
						# identify
						self.send({
							"op":2,
							"d":{
								"token":TOKEN,
								"intents": intents.GUILD_MESSAGES,
								"properties": {
									"os":"linux",
									"browser":"Discord Image Bridge "+_values.HOMEPAGE,
									"device":"Just another SoC"
									},
								"compress":False,
								"afk":False,
								}
							})

				case 11: # Heartbeat ACK
					self.last_heartbeat_ack=monotonic()

		except Exception:
			import traceback
			tb=traceback.format_exc()
			log("Error processing message:\n"+tb)
	
	def handle_dispatch(self, payload, data):
		for function in self.registered_functions.get(payload["t"], []):
			try:
				function(data=data, payload=payload)
			except Exception:
				import traceback
				log("Error in dispatch handler",function.__name__,"for event",payload["t"],
					":\n"+traceback.format_exc())


	def event(self, events:str|list=None):
		# basically, to use a method as decorator, the method returns a decorator function.
		# due to python scoping being weird, we also have to redefine events under a different
		# name to avoid scope issues
		def decorator(func:callable):
			if type(events)==str:
				event_list=[events]
			else:
				event_list=events

			if not event_list:
				event_list=[func.__name__.upper().removeprefix("ON_")]

			for event in event_list:
				event=event.upper()
				#log("Registring function",func.__name__,"under event",event)
				self.registered_functions.setdefault(event, []).append(func)
		return decorator
	
	def find_handler(self, func):
		for event, funcs in self.registered_functions.items():
			if func in funcs:
				yield event
	
	def remove_handler(self, func, event_list:str|list=None):
		if type(event_list)==str:
			event_list=[event_list]
		if not event_list:
			event_list=self.find_handler(func)

		for event in event_list:
			event=event.upper()
			self.registered_functions.get(event, []).remove(func)
	


class DiscordWsManager(cherrypy.process.plugins.SimplePlugin):
	"""
	Cherrypy plugin that keeps a gateway connection alive in a background thread.

	Each time the current client's connection closes, a new DiscordWsClient
	is created from the old one, until stop() is called.
	"""
	def __init__(self, bus):
		cherrypy.process.plugins.SimplePlugin.__init__(self, bus)
		# placeholder until the first real connection is created
		self.client:DiscordWsClient=DiscordWsClientMandatoryAttrs()
		# set by stop() to end the connection loop
		self.closing=Event()
		# created in start(); None means the loop was never started
		self.manager_thread=None

	def start(self):
		"""Subscribe to the ready event and start the connection loop thread."""
		# allow a start after a previous stop
		self.closing.clear()
		# subscribe first so the ready event cannot be published before we listen
		cherrypy.engine.subscribe("discord-gateway-ready",self.register_event_handlers)
		self.manager_thread=Thread(target=self.manager)
		self.manager_thread.start()

	def manager(self):
		"""
		Connection loop: connect, wait for the connection to close, repeat.

		If the previous client is resumable its resume_url is reused; otherwise
		a fresh URL is requested from /gateway/bot and, when no session starts
		remain, the loop waits for the limit to reset first.
		"""
		log("Starting discord gateway connection loop")
		while not self.closing.is_set():
			#log(self.client.can_resume)
			if self.client and self.client.can_resume:
				log("Reconnecting to gateway")
				url=self.client.resume_url
			else:
				log("Requesting new gateway URL")
				rq=Request("GET","/gateway/bot")
				resp=rl.sendrequest(rq, 'gateway')
				data=resp.json()

				url=data["url"]

				limit=data["session_start_limit"]
				log("Reconnects left:",limit["remaining"],"; resets after:",limit["reset_after"])
				if limit["remaining"]<=0:
					# reset_after is a duration in milliseconds, not a timestamp
					seconds_until_reset=limit["reset_after"]/1000
					log("No session starts left; waiting",seconds_until_reset,"seconds for the limit to reset")
					# wait on the closing event instead of sleeping so stop() is not held up
					self.closing.wait(seconds_until_reset)

			# stop() may have been called while requesting the URL or waiting;
			# connecting now would leave stop() stuck in join()
			if self.closing.is_set():
				break

			log("Using gateway url:",url)

			self.client=DiscordWsClient(self.client, url)
			self.client.resume_url=url
			self.client.connect()
			self.client.closed_event.wait() # wait for socket to exit
			log("Gateway disconnected.")


	def stop(self):
		"""Close the current connection and wait for the connection loop to finish."""
		self.closing.set()
		# the placeholder has no socket, so only a real client needs closing
		if isinstance(self.client, DiscordWsClient):
			self.client.close()
		log("Waiting for gateway socket to close (may take a bit)...")
		if self.manager_thread is not None:
			self.manager_thread.join()
		cherrypy.engine.unsubscribe("discord-gateway-ready",self.register_event_handlers)

	def register_event_handlers(self, client:DiscordWsClient, data, payload):
		"""
		Handler for the first "discord-gateway-ready" publication.

		Unsubscribes itself and then calls utils.on_ready(self, client), so
		on_ready runs once per start().
		"""
		from .utils import on_ready
		cherrypy.engine.unsubscribe("discord-gateway-ready",self.register_event_handlers)
		on_ready(self, client)

def _response2return(repl):
	return (repl.status_code, repl.json() if repl.text else None)


def channel_message_send(channel_id, content=None, **params):
	"""
	POST a message to /channels/<channel_id>/messages.

	channel_id: target channel; converted with str() for the path.
	content: value for the "content" field; always part of the JSON body,
	even when None.
	**params: further fields of the JSON body.

	Returns (status_code, parsed JSON or None).
	"""
	channel=str(channel_id)
	body={**params, "content":content}
	request=Request("POST", "/channels/"+channel+"/messages", json=body)
	# second argument presumably names the rate-limit queue - verify against sendqueue
	reply=rl.sendrequest(request, "channels/"+channel)
	return _response2return(reply)

def channel_message_get(channel_id, message_id):
	"""
	GET a single message from /channels/<channel_id>/messages/<message_id>.

	Both ids are converted with str() for the path.

	Returns (status_code, parsed JSON or None).
	"""
	channel=str(channel_id)
	path="/channels/"+channel+"/messages/"+str(message_id)
	request=Request("GET", path)
	# second argument presumably names the rate-limit queue - verify against sendqueue
	reply=rl.sendrequest(request, "channels/"+channel)
	return _response2return(reply)