mirror of
https://github.com/MeoProject/lx-music-api-server.git
synced 2025-07-06 22:42:14 +08:00
feat: 更新太大了不想总结自己去看提交记录吧(已知把配置文件换成了yaml
This commit is contained in:
149
main.py
149
main.py
@ -1,22 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# ----------------------------------------
|
||||
# - mode: python -
|
||||
# - author: helloplhm-qwq -
|
||||
# - name: main.py -
|
||||
# - project: lx-music-api-server -
|
||||
# - license: MIT -
|
||||
# - mode: python -
|
||||
# - author: helloplhm-qwq -
|
||||
# - name: main.py -
|
||||
# - project: lx-music-api-server -
|
||||
# - license: MIT -
|
||||
# ----------------------------------------
|
||||
# This file is part of the "lx-music-api-server" project.
|
||||
|
||||
import time
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import traceback
|
||||
import threading
|
||||
import ujson as json
|
||||
from aiohttp.web import Response, FileResponse, StreamResponse
|
||||
from io import TextIOWrapper
|
||||
import sys
|
||||
|
||||
from common.utils import createBase64Decode
|
||||
import os
|
||||
|
||||
if ((sys.version_info.major == 3 and sys.version_info.minor < 6) or sys.version_info.major == 2):
|
||||
print('Python版本过低,请使用Python 3.6+ ')
|
||||
sys.exit(1)
|
||||
|
||||
# fix: module not found: common/modules
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from common import utils
|
||||
from common import config, localMusic
|
||||
from common import lxsecurity
|
||||
from common import log
|
||||
@ -24,24 +35,17 @@ from common import Httpx
|
||||
from common import variable
|
||||
from common import scheduler
|
||||
from common import lx_script
|
||||
from aiohttp.web import Response, FileResponse, StreamResponse
|
||||
import ujson as json
|
||||
import threading
|
||||
import traceback
|
||||
import modules
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import time
|
||||
import os
|
||||
|
||||
def handleResult(dic, status = 200) -> Response:
|
||||
def handleResult(dic, status=200) -> Response:
|
||||
if (not isinstance(dic, dict)):
|
||||
dic = {
|
||||
'code': 0,
|
||||
'msg': 'success',
|
||||
'data': dic
|
||||
}
|
||||
return Response(body = json.dumps(dic, indent=2, ensure_ascii=False), content_type='application/json', status = status)
|
||||
return Response(body=json.dumps(dic, indent=2, ensure_ascii=False), content_type='application/json', status=status)
|
||||
|
||||
|
||||
logger = log.log("main")
|
||||
aiologger = log.log('aiohttp_web')
|
||||
@ -54,20 +58,24 @@ if (sys.version_info.minor < 8 and sys.version_info.major == 3):
|
||||
else:
|
||||
stopEvent = asyncio.exceptions.CancelledError
|
||||
|
||||
|
||||
def start_checkcn_thread() -> None:
|
||||
threading.Thread(target=Httpx.checkcn).start()
|
||||
|
||||
# check request info before start
|
||||
|
||||
|
||||
async def handle_before_request(app, handler):
|
||||
async def handle_request(request):
|
||||
try:
|
||||
if (config.read_config('common.reverse_proxy.allow_proxy')):
|
||||
if (request.headers.get(config.read_config('common.reverse_proxy.real_ip_header'))):
|
||||
# proxy header
|
||||
if (request.remote in config.read_config('common.reverse_proxy.proxy_whitelist_remote')):
|
||||
request.remote_addr = request.headers.get(config.read_config('common.reverse_proxy.real_ip_header'))
|
||||
if (config.read_config('common.reverse_proxy.allow_public_ip') and (not utils.is_private_ip(request.remote))):
|
||||
request.remote_addr = request.headers.get(
|
||||
config.read_config('common.reverse_proxy.real_ip_header'))
|
||||
else:
|
||||
return handleResult({"code": 1, "msg": "反代客户端远程地址不在反代ip白名单中", "data": None}, 403)
|
||||
return handleResult({"code": 1, "msg": "不允许的公网ip转发", "data": None}, 403)
|
||||
else:
|
||||
request.remote_addr = request.remote
|
||||
else:
|
||||
@ -80,13 +88,13 @@ async def handle_before_request(app, handler):
|
||||
(time.time() - config.getRequestTime('global'))
|
||||
<
|
||||
(config.read_config("security.rate_limit.global"))
|
||||
):
|
||||
):
|
||||
return handleResult({"code": 5, "msg": "全局限速", "data": None}, 429)
|
||||
if (
|
||||
(time.time() - config.getRequestTime(request.remote_addr))
|
||||
<
|
||||
(config.read_config("security.rate_limit.ip"))
|
||||
):
|
||||
):
|
||||
return handleResult({"code": 5, "msg": "IP限速", "data": None}, 429)
|
||||
# update request time
|
||||
config.updateRequestTime('global')
|
||||
@ -95,27 +103,32 @@ async def handle_before_request(app, handler):
|
||||
if (config.read_config("security.allowed_host.enable")):
|
||||
if request.host.split(":")[0] not in config.read_config("security.allowed_host.list"):
|
||||
if config.read_config("security.allowed_host.blacklist.enable"):
|
||||
config.ban_ip(request.remote_addr, int(config.read_config("security.allowed_host.blacklist.length")))
|
||||
config.ban_ip(request.remote_addr, int(
|
||||
config.read_config("security.allowed_host.blacklist.length")))
|
||||
return handleResult({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}, 404)
|
||||
|
||||
resp = await handler(request)
|
||||
if (isinstance(resp, (str, list, dict))):
|
||||
resp = handleResult(resp)
|
||||
elif (isinstance(resp, tuple) and len(resp) == 2): # flask like response
|
||||
elif (isinstance(resp, tuple) and len(resp) == 2): # flask like response
|
||||
body, status = resp
|
||||
if (isinstance(body, (str, list, dict))):
|
||||
resp = handleResult(body, status)
|
||||
else:
|
||||
resp = Response(body = str(body), content_type='text/plain', status = status)
|
||||
resp = Response(
|
||||
body=str(body), content_type='text/plain', status=status)
|
||||
elif (not isinstance(resp, (Response, FileResponse, StreamResponse))):
|
||||
resp = Response(body = str(resp), content_type='text/plain', status = 200)
|
||||
aiologger.info(f'{request.remote_addr + ("" if (request.remote == request.remote_addr) else f"|proxy@{request.remote}")} - {request.method} "{request.path}", {resp.status}')
|
||||
resp = Response(
|
||||
body=str(resp), content_type='text/plain', status=200)
|
||||
aiologger.info(
|
||||
f'{request.remote_addr + ("" if (request.remote == request.remote_addr) else f"|proxy@{request.remote}")} - {request.method} "{request.path}", {resp.status}')
|
||||
return resp
|
||||
except:
|
||||
except:
|
||||
logger.error(traceback.format_exc())
|
||||
return {"code": 4, "msg": "内部服务器错误", "data": None}
|
||||
return handle_request
|
||||
|
||||
|
||||
async def main(request):
|
||||
return handleResult({"code": 0, "msg": "success", "data": None})
|
||||
|
||||
@ -136,7 +149,7 @@ async def handle(request):
|
||||
if (config.read_config('security.lxm_ban.enable')):
|
||||
config.ban_ip(request.remote_addr)
|
||||
return handleResult({"code": 1, "msg": "lxm请求头验证失败", "data": None}, 403)
|
||||
|
||||
|
||||
try:
|
||||
query = dict(request.query)
|
||||
if (method in dir(modules)):
|
||||
@ -147,14 +160,17 @@ async def handle(request):
|
||||
logger.error(traceback.format_exc())
|
||||
return handleResult({'code': 4, 'msg': '内部服务器错误', 'data': None}, 500)
|
||||
|
||||
|
||||
async def handle_404(request):
|
||||
return handleResult({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}, 404)
|
||||
|
||||
|
||||
async def handle_local(request):
|
||||
try:
|
||||
query = dict(request.query)
|
||||
data = query.get('q')
|
||||
data = createBase64Decode(data.replace('-', '+').replace('_', '/'))
|
||||
data = utils.createBase64Decode(
|
||||
data.replace('-', '+').replace('_', '/'))
|
||||
data = json.loads(data)
|
||||
t = request.match_info.get('type')
|
||||
data['t'] = t
|
||||
@ -208,28 +224,30 @@ if (config.read_config('common.allow_download_script')):
|
||||
# 404
|
||||
app.router.add_route('*', '/{tail:.*}', handle_404)
|
||||
|
||||
async def run_app():
|
||||
|
||||
async def run_app_host(host):
|
||||
retries = 0
|
||||
while True:
|
||||
if (retries > 4):
|
||||
logger.warning("重试次数已达上限,但仍有部分端口未能完成监听,已自动进行忽略")
|
||||
return
|
||||
break
|
||||
try:
|
||||
host = config.read_config('common.host')
|
||||
ports = [int(port) for port in config.read_config('common.ports')]
|
||||
ssl_ports = [int(port) for port in config.read_config('common.ssl_info.ssl_ports')]
|
||||
|
||||
ports = [int(port)
|
||||
for port in config.read_config('common.ports')]
|
||||
ssl_ports = [int(port) for port in config.read_config(
|
||||
'common.ssl_info.ssl_ports')]
|
||||
final_ssl_ports = []
|
||||
final_ports = []
|
||||
for p in ports:
|
||||
if (p not in ssl_ports and p not in variable.running_ports):
|
||||
if (p not in ssl_ports and f'{host}_{p}' not in variable.running_ports):
|
||||
final_ports.append(p)
|
||||
else:
|
||||
if (p not in variable.running_ports):
|
||||
final_ssl_ports.append(p)
|
||||
# 读取证书和私钥路径
|
||||
cert_path = config.read_config('common.ssl_info.path.cert')
|
||||
privkey_path = config.read_config('common.ssl_info.path.privkey')
|
||||
privkey_path = config.read_config(
|
||||
'common.ssl_info.path.privkey')
|
||||
|
||||
# 创建 HTTP AppRunner
|
||||
http_runner = aiohttp.web.AppRunner(app)
|
||||
@ -238,16 +256,21 @@ async def run_app():
|
||||
# 启动 HTTP 端口监听
|
||||
for port in final_ports:
|
||||
if (port not in variable.running_ports):
|
||||
http_site = aiohttp.web.TCPSite(http_runner, host, port)
|
||||
http_site = aiohttp.web.TCPSite(
|
||||
http_runner, host, port)
|
||||
await http_site.start()
|
||||
variable.running_ports.append(port)
|
||||
logger.info(f"监听 -> http://{host}:{port}")
|
||||
variable.running_ports.append(f'{host}_{port}')
|
||||
logger.info(f"""监听 -> http://{
|
||||
host if (':' not in host)
|
||||
else '[' + host + ']'
|
||||
}:{port}""")
|
||||
|
||||
if (config.read_config("common.ssl_info.enable") and final_ssl_ports != []):
|
||||
if (os.path.exists(cert_path) and os.path.exists(privkey_path)):
|
||||
import ssl
|
||||
# 创建 SSL 上下文,加载配置文件中指定的证书和私钥
|
||||
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context = ssl.create_default_context(
|
||||
ssl.Purpose.CLIENT_AUTH)
|
||||
ssl_context.load_cert_chain(cert_path, privkey_path)
|
||||
|
||||
# 创建 HTTPS AppRunner
|
||||
@ -257,12 +280,16 @@ async def run_app():
|
||||
# 启动 HTTPS 端口监听
|
||||
for port in ssl_ports:
|
||||
if (port not in variable.running_ports):
|
||||
https_site = aiohttp.web.TCPSite(https_runner, host, port, ssl_context=ssl_context)
|
||||
https_site = aiohttp.web.TCPSite(
|
||||
https_runner, host, port, ssl_context=ssl_context)
|
||||
await https_site.start()
|
||||
variable.running_ports.append(port)
|
||||
logger.info(f"监听 -> https://{host}:{port}")
|
||||
|
||||
return
|
||||
variable.running_ports.append(f'{host}_{port}')
|
||||
logger.info(f"""监听 -> http://{
|
||||
host if (':' not in host)
|
||||
else '[' + host + ']'
|
||||
}:{port}""")
|
||||
logger.debug(f"HOST({host}) 已完成监听")
|
||||
break
|
||||
except OSError as e:
|
||||
if (str(e).startswith("[Errno 98]") or str(e).startswith('[Errno 10048]')):
|
||||
logger.error("端口已被占用,请检查\n" + str(e))
|
||||
@ -271,7 +298,12 @@ async def run_app():
|
||||
logger.info('重新尝试启动...')
|
||||
retries += 1
|
||||
else:
|
||||
raise
|
||||
logger.error("未知错误,请检查\n" + traceback.format_exc())
|
||||
|
||||
|
||||
async def run_app():
|
||||
for host in config.read_config('common.hosts'):
|
||||
await run_app_host(host)
|
||||
|
||||
|
||||
async def initMain():
|
||||
@ -294,7 +326,7 @@ async def initMain():
|
||||
logger.info('wating for sessions to complete...')
|
||||
if variable.aioSession:
|
||||
await variable.aioSession.close()
|
||||
|
||||
|
||||
variable.running = False
|
||||
logger.info("Server stopped")
|
||||
|
||||
@ -305,5 +337,20 @@ if __name__ == "__main__":
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except:
|
||||
logger.error('初始化出错,请检查日志')
|
||||
logger.error(traceback.format_exc())
|
||||
logger.critical('初始化出错,请检查日志')
|
||||
logger.critical(traceback.format_exc())
|
||||
with open('dumprecord_{}.txt'.format(int(time.time())), 'w', encoding='utf-8') as f:
|
||||
f.write(traceback.format_exc())
|
||||
e = '\n\nGlobal variable object:\n\n'
|
||||
for k in dir(variable):
|
||||
e += (k + ' = ' + str(getattr(variable, k)) + '\n') if (not k.startswith('_')) else ''
|
||||
f.write(e)
|
||||
e = '\n\nsys.modules:\n\n'
|
||||
for k in sys.modules:
|
||||
e += (k + ' = ' + str(sys.modules[k]) + '\n') if (not k.startswith('_')) else ''
|
||||
f.write(e)
|
||||
logger.critical('dumprecord_{}.txt 已保存至当前目录'.format(int(time.time())))
|
||||
finally:
|
||||
for f in variable.log_files:
|
||||
if (f and isinstance(f, TextIOWrapper)):
|
||||
f.close()
|
||||
|
Reference in New Issue
Block a user