#!/usr/bin/env python3
"""AI drawing skill entry point.

Loads the drawing configuration for the sender (ROBOT_FROM_WX_ID) from MySQL,
calls the selected image backend (JiMeng, DouBao, Z-Image or OpenAI) and sends
the resulting images through the local WeChat client HTTP API.

All diagnostics go to stdout because the skill runner only consumes stdout.
"""
from __future__ import annotations

import argparse
import base64
import json
import os
import re
import subprocess
import sys
import tempfile
import time
import traceback
import urllib.parse
import urllib.request
from pathlib import Path

# The skill runner consumes stdout, so route Python error output there as well.
sys.stderr = sys.stdout

# Set before re-executing after an on-demand bootstrap, so a bootstrap that exits 0
# without actually installing the dependencies cannot make the script loop forever.
_BOOTSTRAP_REEXEC_ENV = "SKILL_BOOTSTRAP_REEXEC"

# File-name prefix of temp files holding decoded OpenAI base64 images; used both
# when writing them and when cleaning them up after sending.
_OPENAI_TEMP_PREFIX = "wechat-openai-image-"


def _skill_root() -> Path:
    """Return the skill directory (the parent of this script's directory)."""
    script_dir = Path(__file__).resolve().parent
    return script_dir.parent


def _skill_venv_python() -> Path:
    """Return the expected interpreter path inside the skill's ``.venv``."""
    venv_dir = _skill_root() / ".venv"
    if sys.platform == "win32":
        return venv_dir / "Scripts" / "python.exe"
    return venv_dir / "bin" / "python"


def _get_python_executable() -> str:
    """Return a Python interpreter path.

    Prefers the running interpreter and falls back to ``python3``/``python`` on PATH.

    Raises:
        RuntimeError: if no interpreter can be located.
    """
    if sys.executable:
        return sys.executable
    import shutil

    for candidate in ("python3", "python"):
        found = shutil.which(candidate)
        if found:
            return found
    raise RuntimeError("无法找到 Python 解释器路径")


def _run_bootstrap() -> None:
    """Run ``bootstrap.py`` next to this script; exit with its return code on failure."""
    bootstrap = Path(__file__).resolve().parent / "bootstrap.py"
    result = subprocess.run([_get_python_executable(), str(bootstrap)])
    if result.returncode != 0:
        raise SystemExit(result.returncode)


def _ensure_skill_venv_python() -> None:
    """Make sure this script runs under the skill's virtualenv interpreter.

    Runs the bootstrap when the virtualenv interpreter is missing, then replaces the
    current process (``os.execv``) with the virtualenv interpreter unless we are
    already running inside it.
    """
    venv_python = _skill_venv_python()
    if not venv_python.is_file():
        _run_bootstrap()
        venv_python = _skill_venv_python()
        if not venv_python.is_file():
            sys.stdout.write("bootstrap 后仍未找到虚拟环境\n")
            raise SystemExit(1)
    venv_dir = _skill_root() / ".venv"
    # Resolve both sides: an unresolved sys.prefix (e.g. when .venv is a symlink)
    # would never compare equal and the script would re-exec itself endlessly.
    if Path(sys.prefix).resolve() == venv_dir.resolve():
        return
    os.execv(str(venv_python), [str(venv_python), str(Path(__file__).resolve()), *sys.argv[1:]])


_ensure_skill_venv_python()

try:
    import pymysql  # type: ignore  # noqa: E402
    from openai import OpenAI  # type: ignore  # noqa: E402
except ModuleNotFoundError:
    if os.environ.get(_BOOTSTRAP_REEXEC_ENV):
        # We already bootstrapped and re-executed once; retrying would loop forever.
        sys.stdout.write("bootstrap 后仍缺少依赖\n")
        raise SystemExit(1)
    _run_bootstrap()
    os.environ[_BOOTSTRAP_REEXEC_ENV] = "1"
    _py = _get_python_executable()
    os.execv(_py, [_py, str(Path(__file__).resolve()), *sys.argv[1:]])


# ---------------------------------------------------------------------------
# Database helpers
# ---------------------------------------------------------------------------
def _mysql_connect():
    """Open a pymysql connection from MYSQL_HOST/PORT/USER/PASSWORD.

    The database name is taken from ROBOT_CODE.

    Raises:
        RuntimeError: if ROBOT_CODE is not set.
    """
    host = os.environ.get("MYSQL_HOST", "127.0.0.1")
    port = int(os.environ.get("MYSQL_PORT", "3306"))
    user = os.environ.get("MYSQL_USER", "root")
    password = os.environ.get("MYSQL_PASSWORD", "")
    database = os.environ.get("ROBOT_CODE", "")
    if not database:
        raise RuntimeError("环境变量 ROBOT_CODE 未配置")
    return pymysql.connect(
        host=host,
        port=port,
        user=user,
        password=password,
        database=database,
        charset="utf8mb4",
        connect_timeout=10,
        read_timeout=30,
    )


def _query_one(conn, sql: str, params: tuple = ()) -> dict | None:
    """Execute *sql* and return the first row as a ``{column: value}`` dict, or None."""
    cur = conn.cursor()
    try:
        cur.execute(sql, params)
        columns = [desc[0] for desc in cur.description] if cur.description else []
        row = cur.fetchone()
    finally:
        # Close the cursor even when execute/fetch raises.
        cur.close()
    if row is None:
        return None
    return dict(zip(columns, row))


# ---------------------------------------------------------------------------
# Settings resolution (mirrors the Go service logic)
# ---------------------------------------------------------------------------
def _parse_settings_json(raw) -> dict | None:
    """Decode an ``image_ai_settings`` column value.

    Returns None when the value is empty/blank, so the caller keeps its current
    settings.

    Raises:
        ValueError: if the value is not valid JSON or does not decode to a JSON object.
    """
    if not raw:
        return None
    if isinstance(raw, (bytes, bytearray)):
        raw = raw.decode("utf-8")
    if not isinstance(raw, str) or not raw.strip():
        return None
    parsed = json.loads(raw)
    if not isinstance(parsed, dict):
        raise ValueError("image_ai_settings 不是 JSON 对象")
    return parsed


def load_drawing_settings(conn, from_wx_id: str) -> tuple[bool, dict]:
    """Return ``(enabled, image_ai_settings_dict)`` for *from_wx_id*.

    Global settings are read first. Then the chat-room row (ids ending in
    ``@chatroom``) or friend row overrides ``enabled`` when its flag is not NULL,
    and replaces the settings dict when its settings JSON is non-empty.
    """
    # 1. global_settings
    gs = _query_one(conn, "SELECT image_ai_enabled, image_ai_settings FROM global_settings LIMIT 1")
    enabled = False
    settings_json: dict = {}
    if gs:
        if gs.get("image_ai_enabled"):
            enabled = bool(gs["image_ai_enabled"])
        parsed = _parse_settings_json(gs.get("image_ai_settings"))
        if parsed is not None:
            settings_json = parsed

    # 2. override from chatroom / friend settings
    if from_wx_id.endswith("@chatroom"):
        override = _query_one(
            conn,
            "SELECT image_ai_enabled, image_ai_settings FROM chat_room_settings WHERE chat_room_id = %s LIMIT 1",
            (from_wx_id,),
        )
    else:
        override = _query_one(
            conn,
            "SELECT image_ai_enabled, image_ai_settings FROM friend_settings WHERE wechat_id = %s LIMIT 1",
            (from_wx_id,),
        )
    if override:
        if override.get("image_ai_enabled") is not None:
            enabled = bool(override["image_ai_enabled"])
        parsed = _parse_settings_json(override.get("image_ai_settings"))
        if parsed is not None:
            settings_json = parsed

    return enabled, settings_json


# ---------------------------------------------------------------------------
# API callers
# ---------------------------------------------------------------------------
def _http_post_json(url: str, body: dict, headers: dict, timeout: int = 300) -> dict:
    """POST *body* as JSON to *url* and return the decoded JSON response."""
    data = json.dumps(body).encode("utf-8")
    req = urllib.request.Request(url, data=data, headers=headers, method="POST")
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


def _http_get_json(url: str, headers: dict, timeout: int = 30) -> dict:
    """GET *url* and return the decoded JSON response."""
    req = urllib.request.Request(url, headers=headers, method="GET")
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


def _coerce_int(value, default: int, minimum: int, maximum: int) -> int:
    """Parse *value* as int (falling back to *default*) and clamp it to [minimum, maximum]."""
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        parsed = default
    return min(max(parsed, minimum), maximum)


def _openai_output_format(config: dict) -> str:
    """Return the configured output format (png/jpeg/webp), defaulting to png."""
    output_format = str(config.get("output_format", "png") or "png").lower()
    if output_format not in {"png", "jpeg", "webp"}:
        return "png"
    return output_format


def _openai_size(config: dict, ratio: str, resolution: str) -> str:
    """Return the OpenAI ``size`` string.

    An explicit ``config["size"]`` wins. Otherwise the ratio/resolution pair is
    mapped to a fixed size; unknown combinations yield ``"auto"``.
    """
    configured = str(config.get("size", "") or "").strip()
    if configured:
        return configured

    normalized_ratio = (ratio or "").replace(" ", "").lower()
    normalized_resolution = (resolution or "").replace(" ", "").lower()
    if normalized_resolution in {"4k", "2160p", "3840x2160"}:
        sizes = {
            "16:9": "3840x2160",
            "9:16": "2160x3840",
            "1:1": "2048x2048",
            "3:2": "3072x2048",
            "2:3": "2048x3072",
        }
    elif normalized_resolution in {"2k", "1440p", "2048"}:
        sizes = {
            "16:9": "2048x1152",
            "9:16": "1152x2048",
            "1:1": "2048x2048",
            "3:2": "2048x1360",
            "2:3": "1360x2048",
        }
    elif normalized_resolution in {"1k", "1024", "1024p"}:
        sizes = {
            "16:9": "1536x864",
            "9:16": "864x1536",
            "1:1": "1024x1024",
            "3:2": "1536x1024",
            "2:3": "1024x1536",
        }
    else:
        return "auto"
    return sizes.get(normalized_ratio, "auto")


def _openai_prompt(prompt: str, negative_prompt: str) -> str:
    """Append the negative prompt as a textual exclusion (OpenAI has no negative-prompt field)."""
    if not negative_prompt:
        return prompt
    return f"{prompt}\n\n不要包含: {negative_prompt}"


def _openai_client(config: dict) -> OpenAI:
    """Build an OpenAI client from the drawing config.

    Raises:
        RuntimeError: if ``api_key`` is missing.
    """
    api_key = str(config.get("api_key", "")).strip()
    if not api_key:
        raise RuntimeError("OpenAI 绘图配置缺少 api_key")
    base_url = str(config.get("base_url", "") or "").strip()
    organization = str(config.get("organization", "") or "").strip()
    project = str(config.get("project", "") or "").strip()
    timeout: float | None = None
    timeout_value = config.get("timeout")
    if timeout_value not in (None, ""):
        timeout = float(timeout_value)
    return OpenAI(
        api_key=api_key,
        base_url=base_url or None,
        organization=organization or None,
        project=project or None,
        timeout=timeout,
    )


def _truncate_debug_payload(value):
    """Recursively copy *value*, shortening ``b64_json`` strings to 50 chars for logging."""
    if isinstance(value, dict):
        return {
            key: (
                f"{item[:50]}..."
                if key == "b64_json" and isinstance(item, str) and len(item) > 50
                else _truncate_debug_payload(item)
            )
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_truncate_debug_payload(item) for item in value]
    return value


def _debug_response(label: str, payload) -> None:
    """Write *payload* (pydantic models are dumped first) as a JSON debug line to stdout."""
    if hasattr(payload, "model_dump"):
        payload = payload.model_dump()
    payload = _truncate_debug_payload(payload)
    sys.stdout.write(f"[debug] {label}: {json.dumps(payload, ensure_ascii=False)}\n")


def _rewrite_openai_image_url(url: str) -> str:
    """Rewrite the internal chatgpt2api host to its public hostname."""
    internal_host = "http://chatgpt2api:80"
    external_host = "http://chatgpt2api.houhoukang.com"
    if url.startswith(internal_host):
        return f"{external_host}{url[len(internal_host):]}"
    return url


def _save_openai_b64_image(b64_data: str, output_format: str) -> str:
    """Decode base64 image data into a new temp file and return its path.

    The file name carries ``_OPENAI_TEMP_PREFIX`` so ``_cleanup_openai_temp_files``
    can delete it after sending.
    """
    # Decode before creating the file so invalid data cannot leave an empty temp file.
    image_bytes = base64.b64decode(b64_data)
    fd, path = tempfile.mkstemp(prefix=_OPENAI_TEMP_PREFIX, suffix=f".{output_format}")
    with os.fdopen(fd, "wb") as fh:
        fh.write(image_bytes)
    return path


def _openai_images_from_response(response, output_format: str = "png") -> list[str]:
    """Collect image outputs from an OpenAI images response.

    Items carrying a ``url`` yield the (rewritten) URL. Items carrying only
    ``b64_json`` are decoded into temp files and yield the local path.
    """
    outputs: list[str] = []
    try:
        for item in getattr(response, "data", []) or []:
            url = getattr(item, "url", None)
            if url:
                outputs.append(_rewrite_openai_image_url(str(url)))
                continue
            b64_data = getattr(item, "b64_json", None)
            if b64_data:
                outputs.append(_save_openai_b64_image(str(b64_data), output_format))
    except Exception:
        # Don't leak temp files already written when a later item fails.
        _cleanup_openai_temp_files(outputs)
        raise
    return outputs


def _is_remote_image_url(value: str) -> bool:
    """Return True for http(s) URLs; anything else is treated as a local path."""
    return urllib.parse.urlparse(value).scheme in {"http", "https"}


def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[str]) -> None:
    """Send images to *from_wx_id* through the local WeChat client API.

    Remote URLs and local file paths go to separate endpoints.
    """
    remote_urls = [value for value in image_outputs if value and _is_remote_image_url(value)]
    local_paths = [value for value in image_outputs if value and not _is_remote_image_url(value)]

    if remote_urls:
        send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/url"
        send_body = {
            "to_wxid": from_wx_id,
            "image_urls": remote_urls,
        }
        response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
        _debug_response("send image url response", response)

    if local_paths:
        send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/local"
        send_body = {
            "to_wxid": from_wx_id,
            "file_path": local_paths,
        }
        response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
        _debug_response("send image local response", response)


def _cleanup_openai_temp_files(image_outputs: list[str]) -> None:
    """Best-effort delete of temp files written for OpenAI base64 images; others are ignored."""
    for value in image_outputs:
        path = Path(value)
        if path.name.startswith(_OPENAI_TEMP_PREFIX) and path.is_file():
            try:
                path.unlink()
            except OSError:
                pass


def call_jimeng(config: dict, prompt: str, model: str, negative_prompt: str, ratio: str, resolution: str) -> list[str]:
    """Call the JiMeng image generation API and return the image URLs.

    Raises:
        RuntimeError: if ``base_url`` or ``sessionid`` is not configured.
    """
    base_url = str(config.get("base_url", "") or "").rstrip("/")
    session_ids = config.get("sessionid", [])
    # Accept a single comma-separated string as well as a list; joining a plain
    # string would otherwise insert a comma between every character.
    if isinstance(session_ids, str):
        session_ids = [sid.strip() for sid in session_ids.split(",") if sid.strip()]
    if not base_url or not session_ids:
        raise RuntimeError("即梦绘图配置缺少 base_url 或 sessionid")

    if not model or model == "none":
        model = "jimeng-5.0"
    if not ratio:
        ratio = "16:9"
    if not resolution:
        resolution = "2k"
    # Fall back to 2k when the leading number of the resolution is above 4 (e.g.
    # "8k"); note that values like "1080p" also hit this branch.
    m = re.search(r"(\d+)", resolution)
    if m and int(m.group(1)) > 4:
        resolution = "2k"

    token = ",".join(session_ids)
    body = {
        "model": model,
        "prompt": prompt,
        "ratio": ratio,
        "resolution": resolution,
        "response_format": "url",
        "sample_strength": 0.5,
    }
    if negative_prompt:
        body["negative_prompt"] = negative_prompt

    resp = _http_post_json(
        f"{base_url}/v1/images/generations",
        body,
        {"Content-Type": "application/json", "Authorization": f"Bearer {token}"},
        timeout=300,
    )
    return [item["url"] for item in resp.get("data") or [] if item.get("url")]


def call_doubao(config: dict, prompt: str, model: str) -> list[str]:
    """Call the DouBao (Volcengine Ark) image generation API and return the image URLs.

    Raises:
        RuntimeError: if ``api_key`` is not configured.
    """
    api_key = config.get("api_key", "")
    if not api_key:
        raise RuntimeError("豆包绘图配置缺少 api_key")

    if not model or model == "none":
        model = "doubao-seedream-4.5"

    # Map friendly model names to actual endpoint model IDs
    model_map = {
        "doubao-seedream-4.5": "doubao-seedream-4-5-251128",
        "doubao-seedream-4.0": "doubao-seedream-4-0-251128",
        "doubao-seedream-3.0-t2i": "doubao-seedream-3-0-t2i-250415",
        "doubao-seededit-3.0-i2i": "doubao-seededit-3-0-i2i-250628",
    }
    actual_model = model_map.get(model, model)

    body = {
        "model": actual_model,
        "prompt": prompt,
        "response_format": "url",
        "size": config.get("size", "2K"),
        "sequential_image_generation": config.get("sequential_image_generation", "auto"),
        "watermark": config.get("watermark", False),
    }
    image_val = config.get("image", "")
    if image_val:
        body["image"] = image_val

    resp = _http_post_json(
        "https://ark.cn-beijing.volces.com/api/v3/images/generations",
        body,
        {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
        timeout=300,
    )
    return [item["url"] for item in resp.get("data") or [] if item.get("url")]


def call_zimage(config: dict, prompt: str, model: str) -> list[str]:
    """Call the Z-Image API: create an async task, then poll it for up to 15 minutes.

    Raises:
        RuntimeError: on missing config, unsupported model, task failure or timeout.
    """
    base_url = str(config.get("base_url", "") or "").rstrip("/")
    api_key = config.get("api_key", "")
    if not base_url or not api_key:
        raise RuntimeError("造相绘图配置缺少 base_url 或 api_key")

    if not model or model == "none":
        model = "Z-Image-Turbo"

    # Map model names
    model_map = {
        "Z-Image": "Tongyi-MAI/Z-Image",
        "Z-Image-Turbo": "Tongyi-MAI/Z-Image-Turbo",
        "Qwen-Image-Edit-2511": "Qwen/Qwen-Image-Edit-2511",
    }
    actual_model = model_map.get(model)
    if actual_model is None:
        raise RuntimeError(f"不支持的造相模型: {model}")

    body = {
        "model": actual_model,
        "prompt": prompt,
        "image_url": config.get("image_url", []),
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "X-ModelScope-Async-Mode": "true",
    }

    # Step 1: create task
    resp = _http_post_json(f"{base_url}/v1/images/generations", body, headers, timeout=30)
    task_id = resp.get("task_id", "")
    if not task_id:
        raise RuntimeError("造相接口未返回 task_id")

    # Step 2: poll for result
    poll_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "X-ModelScope-Task-Type": "image_generation",
    }
    deadline = time.time() + 15 * 60  # 15 minutes
    while time.time() < deadline:
        task_resp = _http_get_json(f"{base_url}/v1/tasks/{task_id}", poll_headers, timeout=30)
        status = task_resp.get("task_status", "")
        if status == "SUCCEED":
            images = task_resp.get("output_images", [])
            if images:
                return images
            raise RuntimeError("造相任务成功但未返回图片")
        if status == "FAILED":
            raise RuntimeError("造相绘图任务失败")
        time.sleep(5)

    raise RuntimeError("造相绘图任务超时")


def call_openai(config: dict, prompt: str, model: str, negative_prompt: str, ratio: str, resolution: str) -> list[str]:
    """Call the OpenAI GPT Image API for text-to-image generation.

    Returns image URLs, or local temp-file paths for items returned as base64.
    """
    client = _openai_client(config)
    output_format = _openai_output_format(config)
    quality = str(config.get("quality", "auto") or "auto")
    moderation = str(config.get("moderation", "auto") or "auto")
    background = str(config.get("background", "auto") or "auto")
    # Transparent backgrounds are deliberately downgraded to "auto".
    if background == "transparent":
        background = "auto"

    kwargs = {
        "model": model or "gpt-image-2",
        "prompt": _openai_prompt(prompt, negative_prompt),
        "n": _coerce_int(config.get("n"), 1, 1, 10),
        "size": _openai_size(config, ratio, resolution),
        "quality": quality,
        "background": background,
        "moderation": moderation,
        "output_format": output_format,
    }
    if output_format in {"jpeg", "webp"} and config.get("output_compression") is not None:
        kwargs["output_compression"] = _coerce_int(config.get("output_compression"), 100, 0, 100)

    response = client.images.generate(**kwargs)
    _debug_response("openai images.generate response", response)
    return _openai_images_from_response(response, output_format)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
JIMENG_MODELS = {"jimeng-4.5", "jimeng-4.6", "jimeng-5.0"}
DOUBAO_MODELS = {"doubao-seedream-4.5", "doubao-seedream-4.0", "doubao-seedream-3.0-t2i", "doubao-seededit-3.0-i2i"}
ZIMAGE_MODELS = {"Z-Image", "Z-Image-Turbo", "Qwen-Image-Edit-2511"}
OPENAI_MODELS = {"gpt-image-2"}


def _parse_cli_params(argv: list[str]) -> dict[str, str]:
    """Parse the supported ``--option value`` pairs.

    Raises:
        ValueError: if any unsupported argument is present.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--prompt", default="")
    parser.add_argument("--model", default="")
    parser.add_argument("--negative_prompt", default="")
    parser.add_argument("--ratio", default="")
    parser.add_argument("--resolution", default="")
    namespace, unknown = parser.parse_known_args(argv)
    if unknown:
        raise ValueError(f"存在不支持的参数: {' '.join(unknown)}")
    return {
        "prompt": namespace.prompt,
        "model": namespace.model,
        "negative_prompt": namespace.negative_prompt,
        "ratio": namespace.ratio,
        "resolution": namespace.resolution,
    }


def main() -> int:
    """Run the drawing skill; return the process exit code (0 also for 'disabled')."""
    if len(sys.argv) < 2:
        sys.stdout.write("缺少输入参数\n")
        return 1

    try:
        params = _parse_cli_params(sys.argv[1:])
    except ValueError as exc:
        sys.stdout.write(f"参数格式错误: {exc}\n")
        return 1

    prompt = params.get("prompt", "").strip()
    if not prompt:
        sys.stdout.write("缺少画图提示词\n")
        return 1

    model = params.get("model", "").strip()
    negative_prompt = params.get("negative_prompt", "").strip()
    ratio = params.get("ratio", "").strip()
    resolution = params.get("resolution", "").strip()

    from_wx_id = os.environ.get("ROBOT_FROM_WX_ID", "").strip()
    if not from_wx_id:
        sys.stdout.write("环境变量 ROBOT_FROM_WX_ID 未配置\n")
        return 1

    # Connect to DB and load settings
    try:
        conn = _mysql_connect()
    except Exception as exc:
        sys.stdout.write(f"数据库连接失败: {exc}\n")
        return 1

    try:
        enabled, settings_json = load_drawing_settings(conn, from_wx_id)
    except Exception as exc:
        sys.stdout.write(f"加载绘图配置失败: {exc}\n")
        return 1
    finally:
        # Single close point; the connection is not needed after settings load.
        try:
            conn.close()
        except Exception:
            pass

    if not enabled:
        sys.stdout.write("AI 绘图未开启\n")
        return 0

    # Default model
    if not model or model == "none":
        model = "jimeng-5.0"

    # Route to correct API
    try:
        image_urls: list[str] = []
        if model in JIMENG_MODELS:
            jimeng_config = settings_json.get("JiMeng", {})
            if not jimeng_config.get("enabled", False):
                sys.stdout.write("即梦绘图未开启\n")
                return 0
            image_urls = call_jimeng(jimeng_config, prompt, model, negative_prompt, ratio, resolution)
        elif model in DOUBAO_MODELS:
            doubao_config = settings_json.get("DouBao", {})
            if not doubao_config.get("enabled", False):
                sys.stdout.write("豆包绘图未开启\n")
                return 0
            image_urls = call_doubao(doubao_config, prompt, model)
        elif model in ZIMAGE_MODELS:
            zimage_config = settings_json.get("Z-Image", {})
            if not zimage_config.get("enabled", False):
                sys.stdout.write("造相绘图未开启\n")
                return 0
            image_urls = call_zimage(zimage_config, prompt, model)
        elif model in OPENAI_MODELS:
            openai_config = settings_json.get("OpenAI", {})
            if not openai_config.get("enabled", False):
                sys.stdout.write("OpenAI 绘图未开启\n")
                return 0
            image_urls = call_openai(openai_config, prompt, model, negative_prompt, ratio, resolution)
        else:
            sys.stdout.write("不支持的 AI 图像模型\n")
            return 1
    except Exception as exc:
        sys.stdout.write(f"调用绘图接口失败: {exc}\n")
        return 1

    if not image_urls:
        sys.stdout.write("未生成任何图像\n")
        return 1

    # Send the images through the local WeChat client API.
    client_port = os.environ.get("ROBOT_WECHAT_CLIENT_PORT", "").strip()
    if not client_port:
        _cleanup_openai_temp_files(image_urls)
        sys.stdout.write("环境变量 ROBOT_WECHAT_CLIENT_PORT 未配置\n")
        return 1

    try:
        _send_image_outputs(client_port, from_wx_id, image_urls)
        sys.stdout.write("图片发送成功\n")
    except Exception as exc:
        sys.stdout.write(f"发送图片失败: {exc}\n")
        return 1
    finally:
        _cleanup_openai_temp_files(image_urls)

    return 0


if __name__ == "__main__":
    try:
        raise SystemExit(main())
    except SystemExit:
        raise
    except Exception:
        traceback.print_exc(file=sys.stdout)
        raise SystemExit(1)