aboutsummarybugs & patchesrefslogtreecommitdiffstats
path: root/src/discord_image_bridge/sendqueue.py
blob: ca18f8b636cc4357a386220c2bac2da90d6a77a9 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Handles Discord API rate limiting for outgoing requests, and adds the required
headers (authentication etc.) to each request.
This module does not formulate any requests itself; it only sends them while
making sure to avoid hitting rate limits.

Note the contents of this module are for internal use only; you should never
have to interact with the classes and functions in here directly.

Implementation details:
	we avoid hitting rate limits by keeping counters based on the X-RateLimit-*
	headers and waiting before a request would exceed them.

	we avoid getting Cloudflare bans by keeping a Counter of invalid requests
	(HTTP 401, 403 and 429 responses) per minute, and checking the total number
	of them in the last 10 minutes before sending a request, waiting while it
	exceeds 9000 (well below the limit of 10,000 invalid requests per 10 minutes).
"""

from time import sleep
from datetime import datetime, timezone
from collections import Counter

from requests import Request, Session, Response

from . import _values

# Names exported by ``from <this module> import *``.
__all__ = [
	"RateLimitHandler",
]

def time() -> float:
	"""Private function. Return the current time as seconds since the epoch (UTC).

	The value is derived from a timezone-aware UTC datetime, so it does not
	depend on the machine's local timezone. RateLimitHandler compares it with
	the X-RateLimit-Reset values it stores.
	"""
	# the original discarded now() and called timestamp() on the class itself,
	# which raised TypeError on every call
	return datetime.now(tz=timezone.utc).timestamp()

class RateLimitHandler():
	"""Sends requests to the Discord API while respecting its rate limits.

	From the module description:
	Note the contents of this module are to be used for internal purposes, you
	should never have to interact with [this class] directly.
	See module description for details.

	Instances are callable: ``handler(request, routekey)`` sends ``request``
	and returns the ``Response``.
	"""

	# How many times a request is re-sent after an HTTP 429. Once exhausted,
	# the last 429 response is returned instead of retrying forever.
	max_retries = 5

	# Cloudflare ban avoidance: wait while more than this many invalid requests
	# were recorded within the last `invalid_request_window_minutes` minutes
	# (the hard limit is 10000 invalid requests per 10 minutes).
	invalid_request_threshold = 9000
	invalid_request_window_minutes = 10

	def __init__(self):
		# global ratelimit state: requests left, and the epoch time it resets at
		self.global_left=1
		self.global_retry_after=0

		# per-bucket state, keyed by the bucket id from X-RateLimit-Bucket
		self.retry_after:dict[str,float] = dict() # epoch time the bucket resets at
		self.bucket_left:dict[str,int] = dict() # requests left in the bucket
		# maps method+routekey to the bucket id Discord reported for that route
		self.bucket_keys:dict[str,str] = dict()
		self.session=Session()

		# for cloudflare ban avoidance: counts invalid (401/403/429) responses
		# per minute since the epoch, to track the 10000 per 10 minutes limit.
		self.invalid_requests_per_minute = Counter()

	def __call__(self,request:Request, routekey, _retries=0):
		"""Sends a request to discord, taking ratelimiting into account.

		request: the request to send. A URL not starting with _values.BASEURL
			gets it prepended, and the auth/user-agent headers are added.
		routekey: identifies the route; together with the HTTP method it is
			used to look up the ratelimit bucket of the request.
		_retries: internal, number of 429 retries already done.

		Returns the Response. On HTTP 429 the request is retried (after waiting
		for the Retry-After header) up to `max_retries` times; after that the
		429 response is returned as-is.
		"""
		self._prepare(request)
		routeid = request.method + routekey
		# prepare once, so the body is only encoded once and the same
		# PreparedRequest is re-sent on retries
		sendable = self.session.prepare_request(request)

		while True:
			self.avoid_cloudflare_ban() # avoid the 1h cloudflare ban on too many errors
			self._wait_for_ratelimits(routeid)

			# send request
			resp = self.session.send(sendable)
			self.update_invalid_per_minute(resp)
			self._update_ratelimits(resp, routeid)

			if resp.status_code != 429 or _retries >= self.max_retries:
				return resp

			# we hit a ratelimit anyways: data is updated, wait as told, then retry
			#self.client.log("Ratelimit hit, data updated, retrying soon.", scope="HTTP", severity=_values.log_lvl.warning) # TODO: fix logging
			_retries += 1
			sleep(max(0.0, float(resp.headers.get("Retry-After", 0))))

	def _prepare(self, request:Request):
		"Private method. Make the request URL absolute and add the required headers."
		if not request.url.startswith(_values.BASEURL):
			request.url=_values.BASEURL+request.url

		headers = dict(request.headers or {}) # keep headers the caller set
		headers.update({ # add required headers for authentication
			"Authorization":"Bot "+_values.TOKEN,
			"User-Agent":_values.USERAGENT,
			"Accept":"application/json",
			})
		# requests only generates the multipart/form-data Content-Type (with its
		# boundary) when none is set, so JSON must not be forced on file uploads
		if not request.files:
			headers["Content-Type"] = "application/json"
		request.headers = headers

	def _wait_for_ratelimits(self, routeid):
		"Private method. Sleep while the global or the route's bucket ratelimit is exhausted."
		if self.global_left <= 0 and time()<self.global_retry_after: # avoid passing global ratelimit
			#self.client.log("Waiting for global ratelimit",scope="HTTP",severity=_values.log_lvl.debug) # TODO: fix logging
			# max(): the reset time may pass between the check and the sleep
			sleep(max(0.0, self.global_retry_after - time()))

		bucketkey=self.bucket_keys.get(routeid)
		# avoid passing bucket ratelimit
		if bucketkey and self.bucket_left.get(bucketkey, 1) <= 0 and time()<self.retry_after.get(bucketkey, 0):
			#self.client.log("Waiting for bucket ratelimit",scope="HTTP",severity=_values.log_lvl.debug) # TODO: fix logging
			sleep(max(0.0, self.retry_after[bucketkey] - time()))

	def _update_ratelimits(self, resp:Response, routeid):
		"Private method. Store the ratelimit state reported by the response headers."
		headers = resp.headers
		if "X-RateLimit-Global" in headers:
			self.global_left = int(headers.get("X-RateLimit-Remaining", 0))
			self.global_retry_after = float(headers.get("X-RateLimit-Reset", 0))

		elif headers.get("X-RateLimit-Scope") != "shared":
			bucketkey=headers.get("X-RateLimit-Bucket", self.bucket_keys.get(routeid))
			if bucketkey:
				# remember the route's bucket so the next request to it can wait preemptively
				self.bucket_keys[routeid] = bucketkey
				self.bucket_left[bucketkey] = int(headers.get("X-RateLimit-Remaining",0))
				self.retry_after[bucketkey] = float(headers.get("X-RateLimit-Reset", 0))

	def _get_minutes_since_epoch(self):
		"private method. returns whole minutes since the epoch (UTC)."
		return int(datetime.now(tz=timezone.utc).timestamp() // 60)

	def update_invalid_per_minute(self, resp:Response):
		"""update error counts for cloudflare ban avoidance.

		Responses with status 401, 403 or 429 are counted for the current minute,
		except 429s with ``X-RateLimit-Scope: shared``, which Discord does not
		count as invalid requests.
		"""
		if resp.status_code not in (401, 403, 429):
			return
		if resp.status_code == 429 and resp.headers.get("X-RateLimit-Scope") == "shared":
			return
		self.invalid_requests_per_minute[self._get_minutes_since_epoch()]+=1

	def avoid_cloudflare_ban(self):
		"""
		makes sure we don't trigger a cloudflare ban by waiting (in steps of 30
		seconds) until the amount of http errors in the last
		`invalid_request_window_minutes` minutes is at most
		`invalid_request_threshold`.
		"""
		while True:
			oldest_kept = self._get_minutes_since_epoch() - self.invalid_request_window_minutes
			# drop minutes outside the window before counting; iterate over a
			# copy of the keys since entries are deleted inside the loop
			for minute in list(self.invalid_requests_per_minute):
				if minute < oldest_kept:
					del self.invalid_requests_per_minute[minute]
			if self.invalid_requests_per_minute.total() <= self.invalid_request_threshold:
				return
			sleep(30)


# Module-level handler instance; presumably what the rest of the package uses
# to send requests — TODO confirm against callers.
sendrequest = RateLimitHandler()