feat: 更新太大了不想总结自己去看提交记录吧(已知把配置文件换成了yaml

This commit is contained in:
helloplhm-qwq
2024-04-14 19:06:36 +08:00
parent 45e2e7147d
commit 667d420499
13 changed files with 432 additions and 412 deletions

149
main.py
View File

@ -1,22 +1,33 @@
#!/usr/bin/env python3
# ----------------------------------------
# - mode: python -
# - author: helloplhm-qwq -
# - name: main.py -
# - project: lx-music-api-server -
# - license: MIT -
# - mode: python -
# - author: helloplhm-qwq -
# - name: main.py -
# - project: lx-music-api-server -
# - license: MIT -
# ----------------------------------------
# This file is part of the "lx-music-api-server" project.
import time
import aiohttp
import asyncio
import traceback
import threading
import ujson as json
from aiohttp.web import Response, FileResponse, StreamResponse
from io import TextIOWrapper
import sys
from common.utils import createBase64Decode
import os
if ((sys.version_info.major == 3 and sys.version_info.minor < 6) or sys.version_info.major == 2):
print('Python版本过低请使用Python 3.6+ ')
sys.exit(1)
# fix: module not found: common/modules
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from common import utils
from common import config, localMusic
from common import lxsecurity
from common import log
@ -24,24 +35,17 @@ from common import Httpx
from common import variable
from common import scheduler
from common import lx_script
from aiohttp.web import Response, FileResponse, StreamResponse
import ujson as json
import threading
import traceback
import modules
import asyncio
import aiohttp
import time
import os
def handleResult(dic, status = 200) -> Response:
def handleResult(dic, status=200) -> Response:
if (not isinstance(dic, dict)):
dic = {
'code': 0,
'msg': 'success',
'data': dic
}
return Response(body = json.dumps(dic, indent=2, ensure_ascii=False), content_type='application/json', status = status)
return Response(body=json.dumps(dic, indent=2, ensure_ascii=False), content_type='application/json', status=status)
logger = log.log("main")
aiologger = log.log('aiohttp_web')
@ -54,20 +58,24 @@ if (sys.version_info.minor < 8 and sys.version_info.major == 3):
else:
stopEvent = asyncio.exceptions.CancelledError
def start_checkcn_thread() -> None:
threading.Thread(target=Httpx.checkcn).start()
# check request info before start
async def handle_before_request(app, handler):
async def handle_request(request):
try:
if (config.read_config('common.reverse_proxy.allow_proxy')):
if (request.headers.get(config.read_config('common.reverse_proxy.real_ip_header'))):
# proxy header
if (request.remote in config.read_config('common.reverse_proxy.proxy_whitelist_remote')):
request.remote_addr = request.headers.get(config.read_config('common.reverse_proxy.real_ip_header'))
if (config.read_config('common.reverse_proxy.allow_public_ip') and (not utils.is_private_ip(request.remote))):
request.remote_addr = request.headers.get(
config.read_config('common.reverse_proxy.real_ip_header'))
else:
return handleResult({"code": 1, "msg": "反代客户端远程地址不在反代ip白名单中", "data": None}, 403)
return handleResult({"code": 1, "msg": "不允许的公网ip转发", "data": None}, 403)
else:
request.remote_addr = request.remote
else:
@ -80,13 +88,13 @@ async def handle_before_request(app, handler):
(time.time() - config.getRequestTime('global'))
<
(config.read_config("security.rate_limit.global"))
):
):
return handleResult({"code": 5, "msg": "全局限速", "data": None}, 429)
if (
(time.time() - config.getRequestTime(request.remote_addr))
<
(config.read_config("security.rate_limit.ip"))
):
):
return handleResult({"code": 5, "msg": "IP限速", "data": None}, 429)
# update request time
config.updateRequestTime('global')
@ -95,27 +103,32 @@ async def handle_before_request(app, handler):
if (config.read_config("security.allowed_host.enable")):
if request.host.split(":")[0] not in config.read_config("security.allowed_host.list"):
if config.read_config("security.allowed_host.blacklist.enable"):
config.ban_ip(request.remote_addr, int(config.read_config("security.allowed_host.blacklist.length")))
config.ban_ip(request.remote_addr, int(
config.read_config("security.allowed_host.blacklist.length")))
return handleResult({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}, 404)
resp = await handler(request)
if (isinstance(resp, (str, list, dict))):
resp = handleResult(resp)
elif (isinstance(resp, tuple) and len(resp) == 2): # flask like response
elif (isinstance(resp, tuple) and len(resp) == 2): # flask like response
body, status = resp
if (isinstance(body, (str, list, dict))):
resp = handleResult(body, status)
else:
resp = Response(body = str(body), content_type='text/plain', status = status)
resp = Response(
body=str(body), content_type='text/plain', status=status)
elif (not isinstance(resp, (Response, FileResponse, StreamResponse))):
resp = Response(body = str(resp), content_type='text/plain', status = 200)
aiologger.info(f'{request.remote_addr + ("" if (request.remote == request.remote_addr) else f"|proxy@{request.remote}")} - {request.method} "{request.path}", {resp.status}')
resp = Response(
body=str(resp), content_type='text/plain', status=200)
aiologger.info(
f'{request.remote_addr + ("" if (request.remote == request.remote_addr) else f"|proxy@{request.remote}")} - {request.method} "{request.path}", {resp.status}')
return resp
except:
except:
logger.error(traceback.format_exc())
return {"code": 4, "msg": "内部服务器错误", "data": None}
return handle_request
async def main(request):
return handleResult({"code": 0, "msg": "success", "data": None})
@ -136,7 +149,7 @@ async def handle(request):
if (config.read_config('security.lxm_ban.enable')):
config.ban_ip(request.remote_addr)
return handleResult({"code": 1, "msg": "lxm请求头验证失败", "data": None}, 403)
try:
query = dict(request.query)
if (method in dir(modules)):
@ -147,14 +160,17 @@ async def handle(request):
logger.error(traceback.format_exc())
return handleResult({'code': 4, 'msg': '内部服务器错误', 'data': None}, 500)
async def handle_404(request):
return handleResult({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}, 404)
async def handle_local(request):
try:
query = dict(request.query)
data = query.get('q')
data = createBase64Decode(data.replace('-', '+').replace('_', '/'))
data = utils.createBase64Decode(
data.replace('-', '+').replace('_', '/'))
data = json.loads(data)
t = request.match_info.get('type')
data['t'] = t
@ -208,28 +224,30 @@ if (config.read_config('common.allow_download_script')):
# 404
app.router.add_route('*', '/{tail:.*}', handle_404)
async def run_app():
async def run_app_host(host):
retries = 0
while True:
if (retries > 4):
logger.warning("重试次数已达上限,但仍有部分端口未能完成监听,已自动进行忽略")
return
break
try:
host = config.read_config('common.host')
ports = [int(port) for port in config.read_config('common.ports')]
ssl_ports = [int(port) for port in config.read_config('common.ssl_info.ssl_ports')]
ports = [int(port)
for port in config.read_config('common.ports')]
ssl_ports = [int(port) for port in config.read_config(
'common.ssl_info.ssl_ports')]
final_ssl_ports = []
final_ports = []
for p in ports:
if (p not in ssl_ports and p not in variable.running_ports):
if (p not in ssl_ports and f'{host}_{p}' not in variable.running_ports):
final_ports.append(p)
else:
if (p not in variable.running_ports):
final_ssl_ports.append(p)
# 读取证书和私钥路径
cert_path = config.read_config('common.ssl_info.path.cert')
privkey_path = config.read_config('common.ssl_info.path.privkey')
privkey_path = config.read_config(
'common.ssl_info.path.privkey')
# 创建 HTTP AppRunner
http_runner = aiohttp.web.AppRunner(app)
@ -238,16 +256,21 @@ async def run_app():
# 启动 HTTP 端口监听
for port in final_ports:
if (port not in variable.running_ports):
http_site = aiohttp.web.TCPSite(http_runner, host, port)
http_site = aiohttp.web.TCPSite(
http_runner, host, port)
await http_site.start()
variable.running_ports.append(port)
logger.info(f"监听 -> http://{host}:{port}")
variable.running_ports.append(f'{host}_{port}')
logger.info(f"""监听 -> http://{
host if (':' not in host)
else '[' + host + ']'
}:{port}""")
if (config.read_config("common.ssl_info.enable") and final_ssl_ports != []):
if (os.path.exists(cert_path) and os.path.exists(privkey_path)):
import ssl
# 创建 SSL 上下文,加载配置文件中指定的证书和私钥
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(cert_path, privkey_path)
# 创建 HTTPS AppRunner
@ -257,12 +280,16 @@ async def run_app():
# 启动 HTTPS 端口监听
for port in ssl_ports:
if (port not in variable.running_ports):
https_site = aiohttp.web.TCPSite(https_runner, host, port, ssl_context=ssl_context)
https_site = aiohttp.web.TCPSite(
https_runner, host, port, ssl_context=ssl_context)
await https_site.start()
variable.running_ports.append(port)
logger.info(f"监听 -> https://{host}:{port}")
return
variable.running_ports.append(f'{host}_{port}')
logger.info(f"""监听 -> http://{
host if (':' not in host)
else '[' + host + ']'
}:{port}""")
logger.debug(f"HOST({host}) 已完成监听")
break
except OSError as e:
if (str(e).startswith("[Errno 98]") or str(e).startswith('[Errno 10048]')):
logger.error("端口已被占用,请检查\n" + str(e))
@ -271,7 +298,12 @@ async def run_app():
logger.info('重新尝试启动...')
retries += 1
else:
raise
logger.error("未知错误,请检查\n" + traceback.format_exc())
async def run_app():
for host in config.read_config('common.hosts'):
await run_app_host(host)
async def initMain():
@ -294,7 +326,7 @@ async def initMain():
logger.info('wating for sessions to complete...')
if variable.aioSession:
await variable.aioSession.close()
variable.running = False
logger.info("Server stopped")
@ -305,5 +337,20 @@ if __name__ == "__main__":
except KeyboardInterrupt:
pass
except:
logger.error('初始化出错,请检查日志')
logger.error(traceback.format_exc())
logger.critical('初始化出错,请检查日志')
logger.critical(traceback.format_exc())
with open('dumprecord_{}.txt'.format(int(time.time())), 'w', encoding='utf-8') as f:
f.write(traceback.format_exc())
e = '\n\nGlobal variable object:\n\n'
for k in dir(variable):
e += (k + ' = ' + str(getattr(variable, k)) + '\n') if (not k.startswith('_')) else ''
f.write(e)
e = '\n\nsys.modules:\n\n'
for k in sys.modules:
e += (k + ' = ' + str(sys.modules[k]) + '\n') if (not k.startswith('_')) else ''
f.write(e)
logger.critical('dumprecord_{}.txt 已保存至当前目录'.format(int(time.time())))
finally:
for f in variable.log_files:
if (f and isinstance(f, TextIOWrapper)):
f.close()