feat: 更换flask框架为aiohttp

This commit is contained in:
helloplhm-qwq 2023-12-02 18:12:14 +08:00
parent f4e295763e
commit 842bf64ef6
No known key found for this signature in database
GPG Key ID: 6BE1B64B905567C7
6 changed files with 185 additions and 73 deletions

View File

@ -48,7 +48,7 @@ sourceExpirationTime = {
} }
async def SongURL(source, songId, quality): async def handle_api_request(command, source, songId, quality):
if (source == "kg"): if (source == "kg"):
songId = songId.lower() songId = songId.lower()
try: try:
@ -63,11 +63,11 @@ async def SongURL(source, songId, quality):
except: except:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
try: try:
func = require('apis.' + source).url func = require('apis.' + source + '.' + command)
except: except:
return { return {
'code': 1, 'code': 1,
'msg': '未知的源: ' + source, 'msg': '未知的源或命令',
'data': None, 'data': None,
} }
try: try:

View File

@ -192,6 +192,7 @@ def handle_default_config():
logger.info('首次启动或配置文件被删除,已创建默认配置文件') logger.info('首次启动或配置文件被删除,已创建默认配置文件')
logger.info( logger.info(
f'\n建议您到{variable.workdir + os.path.sep}config.json修改配置后重新启动服务器') f'\n建议您到{variable.workdir + os.path.sep}config.json修改配置后重新启动服务器')
return default
class ConfigReadException(Exception): class ConfigReadException(Exception):

View File

@ -30,8 +30,8 @@ def highlight_error(error):
# 返回语法高亮后的堆栈跟踪字符串 # 返回语法高亮后的堆栈跟踪字符串
return str(highlighted_traceback) return str(highlighted_traceback)
class flaskLogHelper(logging.Handler): class LogHelper(logging.Handler):
# werkzeug日志转接器 # 日志转接器
def __init__(self, custom_logger): def __init__(self, custom_logger):
super().__init__() super().__init__()
self.custom_logger = custom_logger self.custom_logger = custom_logger

View File

@ -18,7 +18,8 @@ import ujson as json
import xmltodict import xmltodict
from urllib.parse import quote from urllib.parse import quote
from hashlib import md5 as _md5 from hashlib import md5 as _md5
from flask import Response from aiohttp.web import Response
# from flask import Response
def to_base64(data_bytes): def to_base64(data_bytes):
encoded_data = base64.b64encode(data_bytes) encoded_data = base64.b64encode(data_bytes)
@ -91,8 +92,8 @@ def unique_list(list_in):
[unique_list.append(x) for x in list_in if x not in unique_list] [unique_list.append(x) for x in list_in if x not in unique_list]
return unique_list return unique_list
def format_dict_json(dic): def handle_response(dic, status = 200):
return Response(json.dumps(dic, indent=2, ensure_ascii=False), mimetype = "application/json") return Response(body = json.dumps(dic, indent=2, ensure_ascii=False), content_type='application/json', status = status)
def encodeURIComponent(component): def encodeURIComponent(component):
return quote(component) return quote(component)

145
main.py
View File

@ -8,88 +8,103 @@
# - license: MIT - # - license: MIT -
# ---------------------------------------- # ----------------------------------------
# This file is part of the "lx-music-api-server" project. # This file is part of the "lx-music-api-server" project.
# Do not edit except you know what you are doing.
# flask from aiohttp import web
from flask import Flask, request
# create flask app
app = Flask("LXMusicTestAPI")
# redirect the default flask logging to custom
import logging
from common import config from common import config
from common import log
flask_logger = log.log('flask')
logging.getLogger('werkzeug').addHandler(log.flaskLogHelper(flask_logger))
logger = log.log("main")
from common import utils
from common import lxsecurity from common import lxsecurity
from common import utils
from common import log
from common import Httpx from common import Httpx
from apis import SongURL from apis import handle_api_request
import traceback import traceback
import time import time
logger = log.log("main")
aiologger = log.log('aiohttp_web')
Httpx.checkcn() Httpx.checkcn()
@app.route('/') # check request info before start
def index(): async def handle_before_request(app, handler):
return utils.format_dict_json({"code": 0, "msg": "success", "data": None}), 200 async def handle_request(request):
# nginx proxy header
if (request.headers.get("X-Real-IP")):
request.remote = request.headers.get("X-Real-IP")
# check ip
if (config.check_ip_banned(request.remote)):
return utils.handle_response({"code": 1, "msg": "您的IP已被封禁", "data": None}, 403)
# check global rate limit
if (
(time.time() - config.getRequestTime('global'))
<
(config.read_config("security.rate_limit.global"))
):
return utils.handle_response({"code": 5, "msg": "全局限速", "data": None}, 429)
if (
(time.time() - config.getRequestTime(request.remote))
<
(config.read_config("security.rate_limit.ip"))
):
return utils.handle_response({"code": 5, "msg": "IP限速", "data": None}, 429)
# update request time
config.updateRequestTime('global')
config.updateRequestTime(request.remote)
# check host
if (config.read_config("security.allowed_host.enable")):
if request.remote_host.split(":")[0] not in config.read_config("security.allowed_host.list"):
if config.read_config("security.allowed_host.blacklist.enable"):
config.ban_ip(request.remote, int(config.read_config("security.allowed_host.blacklist.length")))
return utils.handle_response({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}, 404)
try:
resp = await handler(request)
aiologger.info(f'{request.remote} - {request.method} "{request.path}", {resp.status}')
return resp
except web.HTTPException as ex:
if ex.status == 500: # 捕获500错误
return utils.handle_response({"code": 4, "msg": "内部服务器错误", "data": None}, 500)
else:
logger.error(traceback.format_exc())
return utils.handle_response({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}, 404)
return handle_request
@app.route('/<method>/<source>/<songId>/<quality>') async def main(request):
async def handle(method, source, songId, quality): return utils.handle_response({"code": 0, "msg": "success", "data": None})
async def handle(request):
method = request.match_info.get('method')
source = request.match_info.get('source')
songId = request.match_info.get('songId')
quality = request.match_info.get('quality')
if (config.read_config("security.key.enable") and request.host.split(':')[0] not in config.read_config('security.whitelist_host')): if (config.read_config("security.key.enable") and request.host.split(':')[0] not in config.read_config('security.whitelist_host')):
if (request.headers.get("X-Request-Key")) != config.read_config("security.key.value"): if (request.headers.get("X-Request-Key")) != config.read_config("security.key.value"):
if (config.read_config("security.key.ban")): if (config.read_config("security.key.ban")):
config.ban_ip(request.remote_addr) config.ban_ip(request.remote)
return utils.format_dict_json({"code": 1, "msg": "key验证失败", "data": None}), 403 return utils.handle_response({"code": 1, "msg": "key验证失败", "data": None}, 403)
if (config.read_config('security.check_lxm.enable') and request.host.split(':')[0] not in config.read_config('security.whitelist_host')): if (config.read_config('security.check_lxm.enable') and request.host.split(':')[0] not in config.read_config('security.whitelist_host')):
lxm = request.headers.get('lxm') lxm = request.headers.get('lxm')
if (not lxsecurity.checklxmheader(lxm, request.url)): if (not lxsecurity.checklxmheader(lxm, request.url)):
if (config.read_config('security.lxm_ban.enable')): if (config.read_config('security.lxm_ban.enable')):
config.ban_ip(request.remote_addr) config.ban_ip(request.remote)
return utils.format_dict_json({"code": 1, "msg": "lxm请求头验证失败", "data": None}), 403 return utils.handle_response({"code": 1, "msg": "lxm请求头验证失败", "data": None}, 403)
if method == 'url': try:
try: return utils.handle_response(await handle_api_request(method, source, songId, quality))
return utils.format_dict_json(await SongURL(source, songId, quality)) except Exception as e:
except Exception as e: logger.error(traceback.format_exc())
logger.error(traceback.format_exc()) return utils.handle_response({'code': 4, 'msg': '内部服务器错误', 'data': None}, 500)
return utils.format_dict_json({'code': 4, 'msg': '内部服务器错误', 'data': None}), 500
else:
return utils.format_dict_json({'code': 6, 'msg': '未知的请求类型: ' + method, 'data': None}), 400
@app.errorhandler(500) async def handle_404(request):
def _500(_): return utils.handle_response({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}, 404)
return utils.format_dict_json({'code': 4, 'msg': '内部服务器错误', 'data': None}), 500
@app.errorhandler(404) app = web.Application(middlewares=[handle_before_request])
def _404(_): # mainpage
return utils.format_dict_json({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}), 404 app.router.add_get('/', main)
@app.before_request # api
def check(): app.router.add_get('/{method}/{source}/{songId}/{quality}', handle)
# nginx proxy header
if (request.headers.get("X-Real-IP")): # 404
request.remote_addr = request.headers.get("X-Real-IP") app.router.add_route('*', '/{tail:.*}', handle_404)
# check ip
if (config.check_ip_banned(request.remote_addr)): web.run_app(app, host=config.read_config('common.host'), port=config.read_config('common.port'))
return utils.format_dict_json({"code": 1, "msg": "您的IP已被封禁", "data": None}), 403
# check global rate limit
if ((time.time() - config.getRequestTime('global')) <= (config.read_config("security.rate_limit.global"))):
return utils.format_dict_json({"code": 5, "msg": "全局限速", "data": None}), 429
if ((time.time() - config.getRequestTime(request.remote_addr)) <= (config.read_config("security.rate_limit.ip"))):
return utils.format_dict_json({"code": 5, "msg": "IP限速", "data": None}), 429
# update request time
config.updateRequestTime('global')
config.updateRequestTime(request.remote_addr)
# check host
if (config.read_config("security.allowed_host.enable")):
if request.remote_host.split(":")[0] not in config.read_config("security.allowed_host.list"):
if config.read_config("security.allowed_host.blacklist.enable"):
config.ban_ip(request.remote_addr, int(config.read_config("security.allowed_host.blacklist.length")))
return utils.format_dict_json({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}), 404
# run
app.run(host=config.read_config('common.host'), port=config.read_config('common.port'))

95
orimain.py Normal file
View File

@ -0,0 +1,95 @@
#!/usr/bin/env python3
# ----------------------------------------
# - mode: python -
# - author: helloplhm-qwq -
# - name: main.py -
# - project: lx-music-api-server -
# - license: MIT -
# ----------------------------------------
# This file is part of the "lx-music-api-server" project.
# Do not edit except you know what you are doing.
# flask
from flask import Flask, request
# create flask app
app = Flask("LXMusicTestAPI")
# redirect the default flask logging to custom
import logging
from common import config
from common import log
flask_logger = log.log('flask')
logging.getLogger('werkzeug').addHandler(log.flaskLogHelper(flask_logger))
logger = log.log("main")
from common import utils
from common import lxsecurity
from common import Httpx
from apis import SongURL
import traceback
import time
Httpx.checkcn()
@app.route('/')
def index():
return utils.format_dict_json({"code": 0, "msg": "success", "data": None}), 200
@app.route('/<method>/<source>/<songId>/<quality>')
async def handle(method, source, songId, quality):
if (config.read_config("security.key.enable") and request.host.split(':')[0] not in config.read_config('security.whitelist_host')):
if (request.headers.get("X-Request-Key")) != config.read_config("security.key.value"):
if (config.read_config("security.key.ban")):
config.ban_ip(request.remote_addr)
return utils.format_dict_json({"code": 1, "msg": "key验证失败", "data": None}), 403
if (config.read_config('security.check_lxm.enable') and request.host.split(':')[0] not in config.read_config('security.whitelist_host')):
lxm = request.headers.get('lxm')
if (not lxsecurity.checklxmheader(lxm, request.url)):
if (config.read_config('security.lxm_ban.enable')):
config.ban_ip(request.remote_addr)
return utils.format_dict_json({"code": 1, "msg": "lxm请求头验证失败", "data": None}), 403
if method == 'url':
try:
return utils.format_dict_json(await SongURL(source, songId, quality))
except Exception as e:
logger.error(traceback.format_exc())
return utils.format_dict_json({'code': 4, 'msg': '内部服务器错误', 'data': None}), 500
else:
return utils.format_dict_json({'code': 6, 'msg': '未知的请求类型: ' + method, 'data': None}), 400
@app.errorhandler(500)
def _500(_):
return utils.format_dict_json({'code': 4, 'msg': '内部服务器错误', 'data': None}), 500
@app.errorhandler(404)
def _404(_):
return utils.format_dict_json({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}), 404
@app.before_request
def check():
# nginx proxy header
if (request.headers.get("X-Real-IP")):
request.remote_addr = request.headers.get("X-Real-IP")
# check ip
if (config.check_ip_banned(request.remote_addr)):
return utils.format_dict_json({"code": 1, "msg": "您的IP已被封禁", "data": None}), 403
# check global rate limit
if ((time.time() - config.getRequestTime('global')) <= (config.read_config("security.rate_limit.global"))):
return utils.format_dict_json({"code": 5, "msg": "全局限速", "data": None}), 429
if ((time.time() - config.getRequestTime(request.remote_addr)) <= (config.read_config("security.rate_limit.ip"))):
return utils.format_dict_json({"code": 5, "msg": "IP限速", "data": None}), 429
# update request time
config.updateRequestTime('global')
config.updateRequestTime(request.remote_addr)
# check host
if (config.read_config("security.allowed_host.enable")):
if request.remote_host.split(":")[0] not in config.read_config("security.allowed_host.list"):
if config.read_config("security.allowed_host.blacklist.enable"):
config.ban_ip(request.remote_addr, int(config.read_config("security.allowed_host.blacklist.length")))
return utils.format_dict_json({'code': 6, 'msg': '未找到您所请求的资源', 'data': None}), 404
# run
app.run(host=config.read_config('common.host'), port=config.read_config('common.port'))