365 lines
11 KiB
Python
365 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import subprocess
|
||
import sys
|
||
import traceback
|
||
import urllib.error
|
||
import urllib.request
|
||
from pathlib import Path
|
||
from urllib.parse import urlparse
|
||
|
||
# Send everything written to stderr (including tracebacks) to stdout as well;
# presumably the invoking process only captures stdout -- TODO confirm.
sys.stderr = sys.stdout

# Default instruction sent with the video. It asks for Chinese output in three
# parts: a detailed description, a summary of the key points, and an
# interpretation.
DEFAULT_PROMPT: str = "请用中文输出,分成三部分:1. 详细描述视频内容;2. 总结核心信息;3. 给出对视频的理解。"
# Default frame-sampling rate used when --fps is not given.
DEFAULT_FPS: int = 2
# Default completion cap used when --max_tokens is not given.
DEFAULT_MAX_TOKENS: int = 800
|
||
|
||
|
||
def _skill_root() -> Path:
|
||
return Path(__file__).resolve().parent.parent
|
||
|
||
|
||
def _skill_venv_python() -> Path:
    """Return the interpreter path inside the skill's ``.venv`` directory.

    On ``win32`` this is ``.venv/Scripts/python.exe``; on every other platform
    it is ``.venv/bin/python``. The file is not checked for existence.
    """
    on_windows = sys.platform == "win32"
    subdir, exe_name = ("Scripts", "python.exe") if on_windows else ("bin", "python")
    return _skill_root() / ".venv" / subdir / exe_name
|
||
|
||
|
||
def _get_python_executable() -> str:
|
||
if sys.executable:
|
||
return sys.executable
|
||
import shutil
|
||
for candidate in ("python3", "python"):
|
||
found = shutil.which(candidate)
|
||
if found:
|
||
return found
|
||
raise RuntimeError("无法找到 Python 解释器路径")
|
||
|
||
|
||
def _run_bootstrap() -> None:
    """Run ``bootstrap.py`` (located next to this script) in a child process.

    The child uses the interpreter from ``_get_python_executable()``; its
    output is not captured.

    Raises:
        SystemExit: carrying the child's exit code when it is non-zero.
    """
    script = Path(__file__).resolve().with_name("bootstrap.py")
    completed = subprocess.run([_get_python_executable(), str(script)])
    if completed.returncode:
        raise SystemExit(completed.returncode)
|
||
|
||
|
||
def _ensure_skill_venv_python() -> None:
    """Make sure this script is running under the skill's own venv interpreter.

    If the venv interpreter does not exist, ``bootstrap.py`` is run to create
    it. If the current process is not already inside that venv, the process
    image is replaced via ``os.execv`` with the venv interpreter running this
    same script and arguments. The function returns normally only when the
    process is already inside the venv.

    Raises:
        SystemExit: if bootstrap fails, or the venv is still missing after it.
    """
    venv_python = _skill_venv_python()
    if not venv_python.is_file():
        _run_bootstrap()
        venv_python = _skill_venv_python()
        if not venv_python.is_file():
            sys.stdout.write("bootstrap 后仍未找到虚拟环境\n")
            raise SystemExit(1)

    venv_dir = _skill_root() / ".venv"
    # Compare resolved paths on both sides. sys.prefix may keep an unresolved
    # path (e.g. when .venv is a symlink); comparing it with a resolved path
    # would never match and the execv below would loop forever.
    if Path(sys.prefix).resolve() == venv_dir.resolve():
        return

    # execv replaces the process without flushing Python-level buffers.
    sys.stdout.flush()
    os.execv(str(venv_python), [str(venv_python), str(Path(__file__).resolve()), *sys.argv[1:]])
|
||
|
||
|
||
# Re-exec into the skill venv (creating it if needed) before importing
# third-party packages.
_ensure_skill_venv_python()

try:
    import pymysql  # type: ignore # noqa: E402
except ModuleNotFoundError:
    # The venv exists but lacks pymysql: let bootstrap install it, then retry
    # the import in this process. Re-exec'ing unconditionally here would loop
    # forever if bootstrap exited 0 without actually providing pymysql.
    import importlib

    _run_bootstrap()
    # Drop cached directory listings so freshly installed packages are found.
    importlib.invalidate_caches()
    try:
        import pymysql  # type: ignore # noqa: E402,F811
    except ModuleNotFoundError:
        sys.stdout.write("bootstrap 后仍无法导入 pymysql\n")
        raise SystemExit(1)
|
||
|
||
|
||
def _mysql_connect():
|
||
host = os.environ.get("MYSQL_HOST", "127.0.0.1")
|
||
port = int(os.environ.get("MYSQL_PORT", "3306"))
|
||
user = os.environ.get("MYSQL_USER", "root")
|
||
password = os.environ.get("MYSQL_PASSWORD", "")
|
||
database = os.environ.get("ROBOT_CODE", "")
|
||
if not database:
|
||
raise RuntimeError("环境变量 ROBOT_CODE 未配置")
|
||
|
||
return pymysql.connect(
|
||
host=host,
|
||
port=port,
|
||
user=user,
|
||
password=password,
|
||
database=database,
|
||
charset="utf8mb4",
|
||
connect_timeout=10,
|
||
read_timeout=30,
|
||
)
|
||
|
||
|
||
def _query_one(conn, sql: str, params: tuple = ()) -> dict | None:
|
||
cur = conn.cursor()
|
||
cur.execute(sql, params)
|
||
columns = [desc[0] for desc in cur.description] if cur.description else []
|
||
row = cur.fetchone()
|
||
cur.close()
|
||
if row is None:
|
||
return None
|
||
return dict(zip(columns, row))
|
||
|
||
|
||
def _table_has_column(conn, table_name: str, column_name: str) -> bool:
|
||
sql = (
|
||
"SELECT 1 FROM information_schema.columns "
|
||
"WHERE table_schema = %s AND table_name = %s AND column_name = %s LIMIT 1"
|
||
)
|
||
database_name = conn.db
|
||
if isinstance(database_name, (bytes, bytearray)):
|
||
database_name = database_name.decode("utf-8")
|
||
cur = conn.cursor()
|
||
cur.execute(sql, (database_name, table_name, column_name))
|
||
row = cur.fetchone()
|
||
cur.close()
|
||
return row is not None
|
||
|
||
|
||
def _decode_settings(raw: object) -> dict:
|
||
if not raw:
|
||
return {}
|
||
if isinstance(raw, (bytes, bytearray)):
|
||
raw = raw.decode("utf-8")
|
||
if isinstance(raw, str) and raw.strip():
|
||
return json.loads(raw)
|
||
return {}
|
||
|
||
|
||
def _extract_model(record: dict | None, settings_json: dict) -> str:
|
||
if record:
|
||
model = record.get("image_recognition_model")
|
||
if isinstance(model, (bytes, bytearray)):
|
||
model = model.decode("utf-8")
|
||
if isinstance(model, str) and model.strip():
|
||
return model.strip()
|
||
|
||
for key in ("image_recognition_model", "imageRecognitionModel"):
|
||
value = settings_json.get(key)
|
||
if isinstance(value, str) and value.strip():
|
||
return value.strip()
|
||
|
||
return ""
|
||
|
||
|
||
def load_understanding_settings(conn, from_wx_id: str) -> tuple[bool, str]:
    """Resolve whether image/video AI is enabled, and which model to use.

    The single ``global_settings`` row is read first. A per-conversation
    override row is then applied:
      * ``chat_room_settings`` (matched on ``chat_room_id``) when
        ``from_wx_id`` ends with ``@chatroom``;
      * otherwise ``friend_settings`` (matched on ``wechat_id``).

    Override rules:
      * a non-NULL ``image_ai_enabled`` replaces the global flag;
      * a model found in the override row itself (its column, or its own
        ``image_ai_settings`` JSON) replaces the global model.

    The ``image_recognition_model`` column is only selected on tables that
    have it, so schemas without that column still work.

    Returns:
        ``(enabled, model)``; ``model`` is ``""`` when none is configured.
    """
    base_fields = "image_ai_enabled, image_ai_settings"
    model_column = "image_recognition_model"

    def _fields_for(table_name: str) -> str:
        # Select the model column only where the schema actually has it.
        if _table_has_column(conn, table_name, model_column):
            return f"{base_fields}, {model_column}"
        return base_fields

    enabled = False
    model = ""

    global_record = _query_one(conn, f"SELECT {_fields_for('global_settings')} FROM global_settings LIMIT 1")
    if global_record:
        if global_record.get("image_ai_enabled") is not None:
            enabled = bool(global_record["image_ai_enabled"])
        global_settings = _decode_settings(global_record.get("image_ai_settings"))
        model = _extract_model(global_record, global_settings)

    if from_wx_id.endswith("@chatroom"):
        table_name, key_column = "chat_room_settings", "chat_room_id"
    else:
        table_name, key_column = "friend_settings", "wechat_id"
    # Table/column names are fixed literals above; only the id is bound as a parameter.
    override = _query_one(
        conn,
        f"SELECT {_fields_for(table_name)} FROM {table_name} WHERE {key_column} = %s LIMIT 1",
        (from_wx_id,),
    )

    if override:
        if override.get("image_ai_enabled") is not None:
            enabled = bool(override["image_ai_enabled"])
        # Look for a model only in the override row's own column and settings.
        # Falling back to the global settings JSON here would let an override
        # row without any model replace the global column model.
        override_settings = _decode_settings(override.get("image_ai_settings"))
        override_model = _extract_model(override, override_settings)
        if override_model:
            model = override_model

    return enabled, model
|
||
|
||
|
||
def _http_post_json(url: str, body: dict, headers: dict, timeout: int = 300) -> dict:
|
||
data = json.dumps(body).encode("utf-8")
|
||
req = urllib.request.Request(url, data=data, headers=headers, method="POST")
|
||
try:
|
||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||
return json.loads(resp.read().decode("utf-8"))
|
||
except urllib.error.HTTPError as exc:
|
||
error_body = exc.read().decode("utf-8", errors="replace")
|
||
raise RuntimeError(f"HTTP {exc.code}: {error_body}") from exc
|
||
except urllib.error.URLError as exc:
|
||
raise RuntimeError(str(exc)) from exc
|
||
|
||
|
||
def _extract_response_text(payload: dict) -> str:
|
||
choices = payload.get("choices", [])
|
||
if not choices:
|
||
return ""
|
||
|
||
message = choices[0].get("message", {})
|
||
content = message.get("content", "")
|
||
if isinstance(content, str):
|
||
return content.strip()
|
||
if isinstance(content, list):
|
||
texts: list[str] = []
|
||
for item in content:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
if item.get("type") == "text" and isinstance(item.get("text"), str):
|
||
texts.append(item["text"].strip())
|
||
return "\n".join(text for text in texts if text)
|
||
return ""
|
||
|
||
|
||
def analyze_video(video_url: str, prompt: str, model: str, fps: int, max_tokens: int) -> str:
    """Ask the Volcengine Ark chat-completions endpoint to analyse a video.

    Args:
        video_url: https URL of the video.
        prompt: Text instruction sent together with the video.
        model: Model id loaded from the database settings.
        fps: Frame sampling rate for the video.
        max_tokens: Completion length cap.

    Returns:
        The model's text answer.

    Raises:
        RuntimeError: if ARK_API_KEY or the model is not configured, if the
            HTTP call fails, or if the response contains no text.
    """
    api_key = os.environ.get("ARK_API_KEY", "").strip()
    if not api_key:
        raise RuntimeError("环境变量 ARK_API_KEY 未配置")
    if not model:
        raise RuntimeError("数据库中未配置 image_recognition_model")

    video_part = {
        "type": "video_url",
        # Ark reads the sampling rate as a number inside the video_url object.
        # A string "fps" key placed beside "type" is not part of that schema
        # and may be ignored or rejected.
        "video_url": {"url": video_url, "fps": fps},
    }
    body = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    video_part,
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        "max_tokens": max_tokens,
    }
    response = _http_post_json(
        "https://ark.cn-beijing.volces.com/api/v3/chat/completions",
        body,
        {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
        timeout=300,
    )
    text = _extract_response_text(response)
    if not text:
        raise RuntimeError("视频理解接口未返回文本内容")
    return text
|
||
|
||
|
||
def _validate_video_url(value: str) -> str:
|
||
parsed = urlparse(value)
|
||
if parsed.scheme != "https" or not parsed.netloc:
|
||
raise ValueError("video_url 必须是 https 链接")
|
||
return value
|
||
|
||
|
||
def _parse_cli_params(argv: list[str]) -> dict:
|
||
parser = argparse.ArgumentParser(add_help=False)
|
||
parser.add_argument("--video_url", default="")
|
||
parser.add_argument("--prompt", default=DEFAULT_PROMPT)
|
||
parser.add_argument("--fps", type=int, default=DEFAULT_FPS)
|
||
parser.add_argument("--max_tokens", type=int, default=DEFAULT_MAX_TOKENS)
|
||
|
||
namespace, unknown = parser.parse_known_args(argv)
|
||
if unknown:
|
||
raise ValueError(f"存在不支持的参数: {' '.join(unknown)}")
|
||
if namespace.fps <= 0:
|
||
raise ValueError("fps 必须大于 0")
|
||
if namespace.max_tokens <= 0:
|
||
raise ValueError("max_tokens 必须大于 0")
|
||
|
||
return {
|
||
"video_url": namespace.video_url,
|
||
"prompt": namespace.prompt,
|
||
"fps": namespace.fps,
|
||
"max_tokens": namespace.max_tokens,
|
||
}
|
||
|
||
|
||
def main() -> int:
    """Run the video-understanding command end to end.

    Steps:
      1. Parse and validate the CLI arguments.
      2. Read ROBOT_FROM_WX_ID.
      3. Load the enabled flag and the model from MySQL.
      4. Call the video API and print its answer.

    Each outcome is written to stdout.

    Returns:
        0 on success or when the feature is switched off; 1 on any error.
    """

    def _emit(text: str, code: int = 1) -> int:
        # Write a message to stdout and hand back the exit code to return.
        sys.stdout.write(text)
        return code

    cli_args = sys.argv[1:]
    if not cli_args:
        return _emit("缺少输入参数\n")

    try:
        params = _parse_cli_params(cli_args)
    except ValueError as err:
        return _emit(f"参数格式错误: {err}\n")

    video_url = params.get("video_url", "").strip()
    if not video_url:
        return _emit("缺少视频链接\n")
    try:
        _validate_video_url(video_url)
    except ValueError as err:
        return _emit(f"参数格式错误: {err}\n")

    prompt = params.get("prompt", "").strip() or DEFAULT_PROMPT
    fps = int(params.get("fps", DEFAULT_FPS))
    max_tokens = int(params.get("max_tokens", DEFAULT_MAX_TOKENS))

    from_wx_id = os.environ.get("ROBOT_FROM_WX_ID", "").strip()
    if not from_wx_id:
        return _emit("环境变量 ROBOT_FROM_WX_ID 未配置\n")

    try:
        conn = _mysql_connect()
    except Exception as err:
        return _emit(f"数据库连接失败: {err}\n")

    try:
        enabled, model = load_understanding_settings(conn, from_wx_id)
    except Exception as err:
        return _emit(f"加载视频理解配置失败: {err}\n")
    finally:
        # Best-effort close: the settings (or the error) are already in hand.
        try:
            conn.close()
        except Exception:
            pass

    if not enabled:
        return _emit("AI 图像识别未开启\n", 0)

    try:
        answer = analyze_video(video_url, prompt, model, fps, max_tokens)
    except Exception as err:
        return _emit(f"调用视频理解接口失败: {err}\n")

    return _emit(f"{answer}\n", 0)
|
||
|
||
|
||
if __name__ == "__main__":
    # Exit with main()'s return code. Unexpected exceptions are printed to
    # stdout with a traceback and mapped to exit code 1. SystemExit derives
    # from BaseException rather than Exception, so it propagates past this
    # handler without an explicit re-raise.
    try:
        exit_code = main()
    except Exception:
        traceback.print_exc(file=sys.stdout)
        exit_code = 1
    raise SystemExit(exit_code)