Astor--0/skills/doubao-video-understanding/scripts/video_understanding.py
2026-05-02 00:57:36 +08:00

365 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
import subprocess
import sys
import traceback
import urllib.error
import urllib.request
from pathlib import Path
from urllib.parse import urlparse
# Default instruction sent to the model when --prompt is absent or blank.
DEFAULT_PROMPT = "请用中文输出分成三部分1. 详细描述视频内容2. 总结核心信息3. 给出对视频的理解。"
# Default for --fps; forwarded to the API together with the video URL.
DEFAULT_FPS = 2
# Default for --max_tokens; forwarded to the API as the completion budget.
DEFAULT_MAX_TOKENS = 800
# Send anything written to stderr (e.g. tracebacks) to stdout as well, so all
# diagnostics appear on the single stream this script reports on.
sys.stderr = sys.stdout
def _skill_root() -> Path:
return Path(__file__).resolve().parent.parent
def _skill_venv_python() -> Path:
    """Return the expected interpreter path inside the skill's ``.venv``.

    ``.venv/Scripts/python.exe`` on Windows, ``.venv/bin/python`` elsewhere.
    Existence is not checked here.
    """
    on_windows = sys.platform == "win32"
    subdir, exe_name = ("Scripts", "python.exe") if on_windows else ("bin", "python")
    return _skill_root() / ".venv" / subdir / exe_name
def _get_python_executable() -> str:
if sys.executable:
return sys.executable
import shutil
for candidate in ("python3", "python"):
found = shutil.which(candidate)
if found:
return found
raise RuntimeError("无法找到 Python 解释器路径")
def _run_bootstrap() -> None:
    """Run the sibling ``bootstrap.py`` with the current interpreter.

    Raises:
        SystemExit: carrying bootstrap's exit status when it exits non-zero.
    """
    bootstrap_script = Path(__file__).resolve().with_name("bootstrap.py")
    completed = subprocess.run([_get_python_executable(), str(bootstrap_script)])
    if completed.returncode:
        raise SystemExit(completed.returncode)
def _ensure_skill_venv_python() -> None:
    """Make sure this script runs under the skill's own ``.venv`` interpreter.

    If the venv interpreter is missing, run bootstrap once and re-check. If the
    current process is not already that venv, replace it (``os.execv``) with
    the venv interpreter running this same script and arguments.

    Raises:
        SystemExit: when bootstrap fails or still leaves no venv interpreter.
    """
    venv_python = _skill_venv_python()
    if not venv_python.is_file():
        _run_bootstrap()
        venv_python = _skill_venv_python()
        if not venv_python.is_file():
            sys.stdout.write("bootstrap 后仍未找到虚拟环境\n")
            raise SystemExit(1)
    venv_dir = _skill_root() / ".venv"
    # Resolve BOTH sides. sys.prefix is not guaranteed to be a resolved path; if
    # the venv lives behind a symlink, comparing unresolved vs. resolved paths
    # never matches and the script would execv itself in an endless loop.
    if Path(sys.prefix).resolve() == venv_dir.resolve():
        return
    os.execv(str(venv_python), [str(venv_python), str(Path(__file__).resolve()), *sys.argv[1:]])
_ensure_skill_venv_python()

# Environment marker set just before re-exec'ing after a bootstrap triggered by a
# missing pymysql. If the import still fails in the re-exec'd process, stop with a
# message instead of bootstrapping and re-exec'ing forever.
_PYMYSQL_REEXEC_MARKER = "VIDEO_UNDERSTANDING_PYMYSQL_REEXEC"

try:
    import pymysql  # type: ignore # noqa: E402
except ModuleNotFoundError:
    if os.environ.get(_PYMYSQL_REEXEC_MARKER) == "1":
        sys.stdout.write("bootstrap 后仍无法导入 pymysql\n")
        raise SystemExit(1)
    _run_bootstrap()
    # os.environ updates are inherited by the execv'd process.
    os.environ[_PYMYSQL_REEXEC_MARKER] = "1"
    _py = _get_python_executable()
    os.execv(_py, [_py, str(Path(__file__).resolve()), *sys.argv[1:]])
def _mysql_connect():
host = os.environ.get("MYSQL_HOST", "127.0.0.1")
port = int(os.environ.get("MYSQL_PORT", "3306"))
user = os.environ.get("MYSQL_USER", "root")
password = os.environ.get("MYSQL_PASSWORD", "")
database = os.environ.get("ROBOT_CODE", "")
if not database:
raise RuntimeError("环境变量 ROBOT_CODE 未配置")
return pymysql.connect(
host=host,
port=port,
user=user,
password=password,
database=database,
charset="utf8mb4",
connect_timeout=10,
read_timeout=30,
)
def _query_one(conn, sql: str, params: tuple = ()) -> dict | None:
cur = conn.cursor()
cur.execute(sql, params)
columns = [desc[0] for desc in cur.description] if cur.description else []
row = cur.fetchone()
cur.close()
if row is None:
return None
return dict(zip(columns, row))
def _table_has_column(conn, table_name: str, column_name: str) -> bool:
sql = (
"SELECT 1 FROM information_schema.columns "
"WHERE table_schema = %s AND table_name = %s AND column_name = %s LIMIT 1"
)
database_name = conn.db
if isinstance(database_name, (bytes, bytearray)):
database_name = database_name.decode("utf-8")
cur = conn.cursor()
cur.execute(sql, (database_name, table_name, column_name))
row = cur.fetchone()
cur.close()
return row is not None
def _decode_settings(raw: object) -> dict:
if not raw:
return {}
if isinstance(raw, (bytes, bytearray)):
raw = raw.decode("utf-8")
if isinstance(raw, str) and raw.strip():
return json.loads(raw)
return {}
def _extract_model(record: dict | None, settings_json: dict) -> str:
if record:
model = record.get("image_recognition_model")
if isinstance(model, (bytes, bytearray)):
model = model.decode("utf-8")
if isinstance(model, str) and model.strip():
return model.strip()
for key in ("image_recognition_model", "imageRecognitionModel"):
value = settings_json.get(key)
if isinstance(value, str) and value.strip():
return value.strip()
return ""
def _as_flag(value: object) -> bool:
    """Interpret an ``image_ai_enabled`` column value as a bool.

    NOTE(review): if the column is a MySQL BIT, pymysql presumably returns it
    as bytes (e.g. b"\\x00"), and ``bool(b"\\x00")`` would wrongly be True.
    Bytes are therefore read as a big-endian integer. Ints/bools behave exactly
    as ``bool(value)``. The column type is not visible here; confirm against
    the schema.
    """
    if isinstance(value, (bytes, bytearray)):
        return int.from_bytes(value, "big") != 0
    return bool(value)


def _select_settings_row(
    conn,
    table_name: str,
    key_column: str | None = None,
    key_value: str | None = None,
) -> dict | None:
    """Fetch the first row of *table_name*, optionally filtered by ``key_column = key_value``.

    Selects ``image_ai_enabled`` and ``image_ai_settings``. It also selects
    ``image_recognition_model`` when information_schema reports that the table
    has that column.
    """
    fields = "image_ai_enabled, image_ai_settings"
    if _table_has_column(conn, table_name, "image_recognition_model"):
        fields += ", image_recognition_model"
    # table_name / key_column are fixed identifiers chosen in this module (never
    # user input), so interpolating them is safe; the lookup value is bound as a
    # query parameter.
    sql = f"SELECT {fields} FROM {table_name}"
    params: tuple = ()
    if key_column is not None:
        sql += f" WHERE {key_column} = %s"
        params = (key_value,)
    return _query_one(conn, f"{sql} LIMIT 1", params)


def load_understanding_settings(conn, from_wx_id: str) -> tuple[bool, str]:
    """Resolve whether AI image recognition is enabled, and which model to use, for a chat.

    The first ``global_settings`` row gives the baseline. Note that no
    ORDER BY is used, so with several rows the chosen one is unspecified.

    Then one override row is applied:

    * ``chat_room_settings`` keyed by ``chat_room_id`` when *from_wx_id* ends
      with ``"@chatroom"``;
    * ``friend_settings`` keyed by ``wechat_id`` otherwise.

    A non-NULL override ``image_ai_enabled`` replaces the flag. A non-empty
    override settings JSON replaces the baseline JSON. A non-blank override
    model replaces the model.

    Returns:
        ``(enabled, model)``; *model* is "" when nothing is configured.
    """
    enabled = False
    settings_json: dict = {}
    model = ""

    global_record = _select_settings_row(conn, "global_settings")
    if global_record:
        if global_record.get("image_ai_enabled") is not None:
            enabled = _as_flag(global_record["image_ai_enabled"])
        settings_json = _decode_settings(global_record.get("image_ai_settings"))
        model = _extract_model(global_record, settings_json)

    # Only the override table actually queried needs its schema checked.
    if from_wx_id.endswith("@chatroom"):
        override = _select_settings_row(conn, "chat_room_settings", "chat_room_id", from_wx_id)
    else:
        override = _select_settings_row(conn, "friend_settings", "wechat_id", from_wx_id)

    if override:
        if override.get("image_ai_enabled") is not None:
            enabled = _as_flag(override["image_ai_enabled"])
        override_settings = _decode_settings(override.get("image_ai_settings"))
        if override_settings:
            settings_json = override_settings
        override_model = _extract_model(override, settings_json)
        if override_model:
            model = override_model

    return enabled, model
def _http_post_json(url: str, body: dict, headers: dict, timeout: int = 300) -> dict:
data = json.dumps(body).encode("utf-8")
req = urllib.request.Request(url, data=data, headers=headers, method="POST")
try:
with urllib.request.urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read().decode("utf-8"))
except urllib.error.HTTPError as exc:
error_body = exc.read().decode("utf-8", errors="replace")
raise RuntimeError(f"HTTP {exc.code}: {error_body}") from exc
except urllib.error.URLError as exc:
raise RuntimeError(str(exc)) from exc
def _extract_response_text(payload: dict) -> str:
choices = payload.get("choices", [])
if not choices:
return ""
message = choices[0].get("message", {})
content = message.get("content", "")
if isinstance(content, str):
return content.strip()
if isinstance(content, list):
texts: list[str] = []
for item in content:
if not isinstance(item, dict):
continue
if item.get("type") == "text" and isinstance(item.get("text"), str):
texts.append(item["text"].strip())
return "\n".join(text for text in texts if text)
return ""
def analyze_video(video_url: str, prompt: str, model: str, fps: int, max_tokens: int) -> str:
    """Ask the Ark chat-completions endpoint to analyse *video_url* according to *prompt*.

    Args:
        video_url: public URL of the video.
        prompt: instruction text sent alongside the video.
        model: model identifier to call (read from the DB by the caller).
        fps: frame-sampling rate for the video, sent inside the ``video_url`` object.
        max_tokens: completion token budget.

    Returns:
        The stripped text of the model's reply.

    Raises:
        RuntimeError: if ARK_API_KEY is unset/blank, *model* is empty, or the
            reply contains no text. Errors raised by ``_http_post_json`` propagate.
    """
    api_key = os.environ.get("ARK_API_KEY", "").strip()
    if not api_key:
        raise RuntimeError("环境变量 ARK_API_KEY 未配置")
    if not model:
        raise RuntimeError("数据库中未配置 image_recognition_model")
    body = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    # The Ark video-understanding schema documents fps as a
                    # numeric field of the video_url object. Placed on the
                    # content item (and as a string) it is not part of the
                    # documented request and is likely ignored.
                    {"type": "video_url", "video_url": {"url": video_url, "fps": fps}},
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        "max_tokens": max_tokens,
    }
    response = _http_post_json(
        "https://ark.cn-beijing.volces.com/api/v3/chat/completions",
        body,
        {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
        timeout=300,
    )
    text = _extract_response_text(response)
    if not text:
        raise RuntimeError("视频理解接口未返回文本内容")
    return text
def _validate_video_url(value: str) -> str:
parsed = urlparse(value)
if parsed.scheme != "https" or not parsed.netloc:
raise ValueError("video_url 必须是 https 链接")
return value
def _parse_cli_params(argv: list[str]) -> dict:
    """Parse ``--video_url``, ``--prompt``, ``--fps`` and ``--max_tokens`` from *argv*.

    Returns:
        A dict with keys ``video_url``, ``prompt``, ``fps``, ``max_tokens``. Unset
        options take their defaults (video_url defaults to "").

    Raises:
        ValueError: for unsupported arguments, malformed values (e.g. a
            non-integer ``--fps``), an option missing its value, or a
            non-positive fps / max_tokens.
    """

    class _ValueErrorArgumentParser(argparse.ArgumentParser):
        """ArgumentParser whose parse errors raise ValueError instead of exiting."""

        def error(self, message):  # type: ignore[override]
            # argparse's default error() prints usage and calls sys.exit(2),
            # which would bypass main()'s ValueError reporting and exit code.
            raise ValueError(message)

    parser = _ValueErrorArgumentParser(add_help=False)
    parser.add_argument("--video_url", default="")
    parser.add_argument("--prompt", default=DEFAULT_PROMPT)
    parser.add_argument("--fps", type=int, default=DEFAULT_FPS)
    parser.add_argument("--max_tokens", type=int, default=DEFAULT_MAX_TOKENS)
    namespace, unknown = parser.parse_known_args(argv)
    if unknown:
        raise ValueError(f"存在不支持的参数: {' '.join(unknown)}")
    if namespace.fps <= 0:
        raise ValueError("fps 必须大于 0")
    if namespace.max_tokens <= 0:
        raise ValueError("max_tokens 必须大于 0")
    return {
        "video_url": namespace.video_url,
        "prompt": namespace.prompt,
        "fps": namespace.fps,
        "max_tokens": namespace.max_tokens,
    }
def main() -> int:
    """Command-line entry point.

    Validates the CLI arguments and ROBOT_FROM_WX_ID, then loads the per-chat
    settings from MySQL. If the feature is enabled it calls the video API.
    Every outcome is written to stdout as a single line.

    Returns:
        0 on success or when AI image recognition is disabled, 1 on any error.
    """

    def _fail(message: str) -> int:
        sys.stdout.write(message)
        return 1

    cli_args = sys.argv[1:]
    if not cli_args:
        return _fail("缺少输入参数\n")

    try:
        params = _parse_cli_params(cli_args)
    except ValueError as err:
        return _fail(f"参数格式错误: {err}\n")

    video_url = params.get("video_url", "").strip()
    if not video_url:
        return _fail("缺少视频链接\n")
    try:
        _validate_video_url(video_url)
    except ValueError as err:
        return _fail(f"参数格式错误: {err}\n")

    prompt = params.get("prompt", "").strip() or DEFAULT_PROMPT
    fps = int(params.get("fps", DEFAULT_FPS))
    max_tokens = int(params.get("max_tokens", DEFAULT_MAX_TOKENS))

    from_wx_id = os.environ.get("ROBOT_FROM_WX_ID", "").strip()
    if not from_wx_id:
        return _fail("环境变量 ROBOT_FROM_WX_ID 未配置\n")

    try:
        conn = _mysql_connect()
    except Exception as err:
        return _fail(f"数据库连接失败: {err}\n")
    try:
        enabled, model = load_understanding_settings(conn, from_wx_id)
    except Exception as err:
        return _fail(f"加载视频理解配置失败: {err}\n")
    finally:
        # Best-effort close: a failure here must not mask the result above.
        try:
            conn.close()
        except Exception:
            pass

    if not enabled:
        sys.stdout.write("AI 图像识别未开启\n")
        return 0

    try:
        answer = analyze_video(video_url, prompt, model, fps, max_tokens)
    except Exception as err:
        return _fail(f"调用视频理解接口失败: {err}\n")
    sys.stdout.write(f"{answer}\n")
    return 0
if __name__ == "__main__":
    # SystemExit and KeyboardInterrupt are not Exception subclasses, so they
    # propagate untouched; any other unexpected error is printed to stdout and
    # mapped to exit code 1.
    try:
        _exit_code = main()
    except Exception:
        traceback.print_exc(file=sys.stdout)
        _exit_code = 1
    raise SystemExit(_exit_code)