#!/usr/bin/env python3
"""Generate an image with the configured AI drawing provider and send it to WeChat.

The script reads CLI options (prompt, model, negative prompt, ratio,
resolution), loads the drawing settings for the current conversation from
MySQL, calls the matching provider (JiMeng, DouBao, Z-Image or OpenAI) and
delivers the generated images through the local WeChat client HTTP API.
All diagnostics are written to stdout.
"""
from __future__ import annotations

import argparse
import base64
import json
import mimetypes
import os
import re
import subprocess
import sys
import tempfile
import time
import traceback
import urllib.parse
import urllib.request
from pathlib import Path

# The skill runner consumes stdout, so route Python error output there as well.
sys.stderr = sys.stdout

# Set in the environment before re-exec'ing after a bootstrap, so that a
# second "module not found" exits instead of re-exec'ing forever.
_BOOTSTRAP_RETRY_ENV = "WECHAT_DRAW_SKILL_BOOTSTRAP_RETRIED"


def _skill_root() -> Path:
    """Return the skill root: the parent of the directory holding this script."""
    script_dir = Path(__file__).resolve().parent
    return script_dir.parent


def _skill_venv_python() -> Path:
    """Return the expected interpreter path inside ``<skill_root>/.venv``."""
    venv_dir = _skill_root() / ".venv"
    if sys.platform == "win32":
        return venv_dir / "Scripts" / "python.exe"
    return venv_dir / "bin" / "python"


def _get_python_executable() -> str:
    """Return ``sys.executable``, or the first ``python3``/``python`` on PATH.

    Raises:
        RuntimeError: if no interpreter can be located.
    """
    if sys.executable:
        return sys.executable
    import shutil

    for candidate in ("python3", "python"):
        found = shutil.which(candidate)
        if found:
            return found
    raise RuntimeError("无法找到 Python 解释器路径")


def _run_bootstrap() -> None:
    """Run the sibling ``bootstrap.py``; exit with its return code if it fails."""
    bootstrap = Path(__file__).resolve().parent / "bootstrap.py"
    result = subprocess.run([_get_python_executable(), str(bootstrap)])
    if result.returncode != 0:
        raise SystemExit(result.returncode)


def _ensure_skill_venv_python() -> None:
    """Re-exec this script under the skill's venv interpreter, bootstrapping it if absent.

    Returns normally only when the current interpreter already belongs to the
    venv; otherwise the process image is replaced via ``os.execv``.
    """
    venv_python = _skill_venv_python()
    if not venv_python.is_file():
        _run_bootstrap()
        venv_python = _skill_venv_python()
    if not venv_python.is_file():
        sys.stdout.write("bootstrap 后仍未找到虚拟环境\n")
        raise SystemExit(1)
    venv_dir = _skill_root() / ".venv"
    # Resolve both sides: if either path goes through a symlink, an unresolved
    # comparison never matches and the script would re-exec itself forever.
    if Path(sys.prefix).resolve() == venv_dir.resolve():
        return
    os.execv(str(venv_python), [str(venv_python), str(Path(__file__).resolve()), *sys.argv[1:]])


_ensure_skill_venv_python()

try:
    import pymysql  # type: ignore # noqa: E402
    from openai import OpenAI  # type: ignore # noqa: E402
except ModuleNotFoundError as _import_error:
    # One bootstrap + re-exec attempt only: if the modules are still missing
    # afterwards, another re-exec would just repeat the same failure forever.
    if os.environ.get(_BOOTSTRAP_RETRY_ENV) == "1":
        sys.stdout.write(f"bootstrap 后仍无法导入依赖: {_import_error.name}\n")
        raise SystemExit(1)
    _run_bootstrap()
    os.environ[_BOOTSTRAP_RETRY_ENV] = "1"
    _py = _get_python_executable()
    os.execv(_py, [_py, str(Path(__file__).resolve()), *sys.argv[1:]])


# ---------------------------------------------------------------------------
# Database helpers
# ---------------------------------------------------------------------------

def _mysql_connect():
    """Open a PyMySQL connection configured from environment variables.

    Reads MYSQL_HOST / MYSQL_PORT / MYSQL_USER / MYSQL_PASSWORD (with local
    defaults); the database name comes from ROBOT_CODE.

    Raises:
        RuntimeError: if ROBOT_CODE is not set.
    """
    host = os.environ.get("MYSQL_HOST", "127.0.0.1")
    port = int(os.environ.get("MYSQL_PORT", "3306"))
    user = os.environ.get("MYSQL_USER", "root")
    password = os.environ.get("MYSQL_PASSWORD", "")
    database = os.environ.get("ROBOT_CODE", "")
    if not database:
        raise RuntimeError("环境变量 ROBOT_CODE 未配置")
    return pymysql.connect(
        host=host,
        port=port,
        user=user,
        password=password,
        database=database,
        charset="utf8mb4",
        connect_timeout=10,
        read_timeout=30,
    )


def _query_one(conn, sql: str, params: tuple = ()) -> dict | None:
    """Execute *sql* and return the first row as a ``{column: value}`` dict, or None."""
    # The context manager closes the cursor even when execute/fetchone raises.
    with conn.cursor() as cur:
        cur.execute(sql, params)
        columns = [desc[0] for desc in cur.description] if cur.description else []
        row = cur.fetchone()
    if row is None:
        return None
    return dict(zip(columns, row))


# ---------------------------------------------------------------------------
# Settings resolution (mirrors the Go service logic)
# ---------------------------------------------------------------------------

def _parse_settings_json(raw) -> dict | None:
    """Decode an ``image_ai_settings`` column value.

    Returns the parsed JSON, or None when the value is empty, blank, not text,
    or the JSON literal ``null`` (so such a column never replaces settings
    with None).
    """
    if not raw:
        return None
    if isinstance(raw, (bytes, bytearray)):
        raw = raw.decode("utf-8")
    if not isinstance(raw, str) or not raw.strip():
        return None
    return json.loads(raw)


def load_drawing_settings(conn, from_wx_id: str) -> tuple[bool, dict]:
    """Return (enabled, image_ai_settings_dict) for the conversation *from_wx_id*.

    The first ``global_settings`` row provides the defaults. A matching row in
    ``chat_room_settings`` (ids ending in ``@chatroom``) or ``friend_settings``
    then overrides them: a non-NULL ``image_ai_enabled`` replaces the flag and
    a non-empty ``image_ai_settings`` replaces the whole settings dict (no merge).
    """
    # 1. Global defaults.
    gs = _query_one(conn, "SELECT image_ai_enabled, image_ai_settings FROM global_settings LIMIT 1")
    enabled = False
    settings_json: dict = {}
    if gs:
        enabled = bool(gs.get("image_ai_enabled"))
        parsed = _parse_settings_json(gs.get("image_ai_settings"))
        if parsed is not None:
            settings_json = parsed

    # 2. Per-conversation override from chatroom / friend settings.
    if from_wx_id.endswith("@chatroom"):
        override = _query_one(
            conn,
            "SELECT image_ai_enabled, image_ai_settings FROM chat_room_settings WHERE chat_room_id = %s LIMIT 1",
            (from_wx_id,),
        )
    else:
        override = _query_one(
            conn,
            "SELECT image_ai_enabled, image_ai_settings FROM friend_settings WHERE wechat_id = %s LIMIT 1",
            (from_wx_id,),
        )
    if override:
        if override.get("image_ai_enabled") is not None:
            enabled = bool(override["image_ai_enabled"])
        parsed = _parse_settings_json(override.get("image_ai_settings"))
        if parsed is not None:
            settings_json = parsed
    return enabled, settings_json


# ---------------------------------------------------------------------------
# API callers
# ---------------------------------------------------------------------------

def _http_post_json(url: str, body: dict, headers: dict, timeout: int = 300) -> dict:
    """POST *body* as JSON to *url* and return the decoded JSON response."""
    data = json.dumps(body).encode("utf-8")
    req = urllib.request.Request(url, data=data, headers=headers, method="POST")
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


def _http_get_json(url: str, headers: dict, timeout: int = 30) -> dict:
    """GET *url* and return the decoded JSON response."""
    req = urllib.request.Request(url, headers=headers, method="GET")
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


def _coerce_int(value, default: int, minimum: int, maximum: int) -> int:
    """Convert *value* to int (falling back to *default*) and clamp it to [minimum, maximum]."""
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        parsed = default
    return min(max(parsed, minimum), maximum)


def _openai_output_format(config: dict) -> str:
    """Return the configured output format if it is png/jpeg/webp, else ``"png"``."""
    output_format = str(config.get("output_format", "png") or "png").lower()
    if output_format not in {"png", "jpeg", "webp"}:
        return "png"
    return output_format


def _openai_size(config: dict, ratio: str, resolution: str) -> str:
    """Pick an OpenAI ``size`` string.

    An explicit ``config["size"]`` wins. Otherwise the (ratio, resolution)
    pair is mapped through fixed 4k / 2k / 1k tables; unknown combinations
    yield ``"auto"``.
    """
    configured = str(config.get("size", "") or "").strip()
    if configured:
        return configured
    normalized_ratio = (ratio or "").replace(" ", "").lower()
    normalized_resolution = (resolution or "").replace(" ", "").lower()
    if normalized_resolution in {"4k", "2160p", "3840x2160"}:
        sizes = {
            "16:9": "3840x2160",
            "9:16": "2160x3840",
            "1:1": "2048x2048",
            "3:2": "3072x2048",
            "2:3": "2048x3072",
        }
    elif normalized_resolution in {"2k", "1440p", "2048"}:
        sizes = {
            "16:9": "2048x1152",
            "9:16": "1152x2048",
            "1:1": "2048x2048",
            "3:2": "2048x1360",
            "2:3": "1360x2048",
        }
    elif normalized_resolution in {"1k", "1024", "1024p"}:
        sizes = {
            "16:9": "1536x864",
            "9:16": "864x1536",
            "1:1": "1024x1024",
            "3:2": "1536x1024",
            "2:3": "1024x1536",
        }
    else:
        return "auto"
    return sizes.get(normalized_ratio, "auto")


def _openai_prompt(prompt: str, negative_prompt: str) -> str:
    """Append the negative prompt as a plain-text instruction, since it is sent as part of the prompt."""
    if not negative_prompt:
        return prompt
    return f"{prompt}\n\n不要包含: {negative_prompt}"


def _openai_client(config: dict) -> OpenAI:
    """Build an ``OpenAI`` client from the drawing config.

    Optional keys: ``base_url``, ``organization``, ``project``, ``timeout``
    (seconds). The timeout is passed only when configured, so the SDK's
    default timeout applies otherwise (passing None would disable it).

    Raises:
        RuntimeError: if ``api_key`` is missing, empty, or null.
        ValueError: if ``timeout`` is not numeric.
    """
    # `or ""` so a JSON null does not become the literal string "None".
    api_key = str(config.get("api_key") or "").strip()
    if not api_key:
        raise RuntimeError("OpenAI 绘图配置缺少 api_key")
    base_url = str(config.get("base_url", "") or "").strip()
    organization = str(config.get("organization", "") or "").strip()
    project = str(config.get("project", "") or "").strip()
    client_kwargs: dict = {
        "api_key": api_key,
        "base_url": base_url or None,
        "organization": organization or None,
        "project": project or None,
    }
    timeout_value = config.get("timeout")
    if timeout_value not in (None, ""):
        client_kwargs["timeout"] = float(timeout_value)
    return OpenAI(**client_kwargs)


def _truncate_debug_payload(value):
    """Return a copy of *value* with every ``b64_json`` string longer than 50 chars shortened."""
    if isinstance(value, dict):
        return {
            key: (
                f"{item[:50]}..."
                if key == "b64_json" and isinstance(item, str) and len(item) > 50
                else _truncate_debug_payload(item)
            )
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_truncate_debug_payload(item) for item in value]
    return value


def _debug_response(label: str, payload) -> None:
    """Print *payload* as JSON to stdout when SKILL_DEBUG_LOG is ``true`` or ``1``."""
    if os.environ.get("SKILL_DEBUG_LOG", "").strip().lower() not in {"true", "1"}:
        return
    if hasattr(payload, "model_dump"):
        payload = payload.model_dump()
    payload = _truncate_debug_payload(payload)
    sys.stdout.write(f"[debug] {label}: {json.dumps(payload, ensure_ascii=False)}\n")


def _rewrite_openai_image_url(url: str) -> str:
    """Swap the internal ``chatgpt2api`` host for its public address; other URLs pass through."""
    internal_host = "http://chatgpt2api:80"
    external_host = "https://chatgpt2api.houhoukang.com"
    if url.startswith(internal_host):
        return f"{external_host}{url[len(internal_host):]}"
    return url


def _extension_from_mime(mime_type: str) -> str:
    """Map an image MIME type to .jpg/.png/.jpeg/.webp, defaulting to ``.png``."""
    if mime_type == "image/jpeg":
        return ".jpg"
    guessed = mimetypes.guess_extension(mime_type)
    if guessed in {".png", ".jpg", ".jpeg", ".webp"}:
        return guessed
    return ".png"


def _extension_from_output_format(output_format: str) -> str:
    """Map an OpenAI ``output_format`` value to a file suffix."""
    if output_format == "jpeg":
        return ".jpg"
    if output_format == "webp":
        return ".webp"
    return ".png"


def _openai_response_value(item, key: str):
    """Read *key* from a response item that may be a dict or an SDK object."""
    if isinstance(item, dict):
        return item.get(key)
    return getattr(item, key, None)


# Prefix shared by the temp-file writer and the cleanup routine; cleanup only
# deletes files carrying it.
_OPENAI_TEMP_PREFIX = "wechat-openai-image-"


def _write_openai_b64_image(b64_json: str, output_format: str) -> str:
    """Decode base64 image data (optionally a ``data:`` URL) into a temp file and return its path.

    The suffix comes from the data-URL MIME type when present, else from
    *output_format*. Whitespace is removed and missing ``=`` padding is added
    before decoding. The file is not auto-deleted; callers clean it up.
    """
    encoded = b64_json.strip()
    suffix = _extension_from_output_format(output_format)
    if encoded.startswith("data:"):
        header, encoded = encoded.split(",", 1)
        mime_type = header[5:].split(";", 1)[0].strip().lower()
        if mime_type:
            suffix = _extension_from_mime(mime_type)
    encoded = "".join(encoded.split())
    padding = len(encoded) % 4
    if padding:
        encoded = f"{encoded}{'=' * (4 - padding)}"
    image_bytes = base64.b64decode(encoded)
    with tempfile.NamedTemporaryFile(prefix=_OPENAI_TEMP_PREFIX, suffix=suffix, delete=False) as temp_file:
        temp_file.write(image_bytes)
        return temp_file.name


def _openai_images_from_response(response, output_format: str) -> list[str]:
    """Turn an images response into a list of local temp-file paths and/or remote URLs.

    Items with ``b64_json`` are written to temp files; items with ``url`` are
    host-rewritten. If anything fails, temp files written so far are removed.
    """
    outputs: list[str] = []
    try:
        for item in getattr(response, "data", []) or []:
            b64_json = _openai_response_value(item, "b64_json")
            if b64_json:
                outputs.append(_write_openai_b64_image(str(b64_json), output_format))
                continue
            url = _openai_response_value(item, "url")
            if url:
                outputs.append(_rewrite_openai_image_url(str(url)))
    except Exception:
        _cleanup_openai_temp_files(outputs)
        raise
    return outputs


def _is_remote_image_url(value: str) -> bool:
    """True when *value* is an http(s) URL rather than a local path."""
    return urllib.parse.urlparse(value).scheme in {"http", "https"}


def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[str]) -> None:
    """Send images to *from_wx_id* through the local WeChat client API on *client_port*.

    Remote URLs go in one batched ``send/image/url`` call; each local file is
    sent with its own ``send/image/local`` call.
    """
    remote_urls = [value for value in image_outputs if value and _is_remote_image_url(value)]
    local_paths = [value for value in image_outputs if value and not _is_remote_image_url(value)]
    if remote_urls:
        send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/url"
        send_body = {
            "to_wxid": from_wx_id,
            "image_urls": remote_urls,
        }
        response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=300)
        _debug_response("send image url response", response)
    for file_path in local_paths:
        send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/local"
        send_body = {
            "to_wxid": from_wx_id,
            "file_path": file_path,
        }
        response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=300)
        _debug_response("send image local response", response)


def _cleanup_openai_temp_files(image_outputs: list[str]) -> None:
    """Best-effort removal of temp files this script created; other entries are ignored."""
    for value in image_outputs:
        path = Path(value)
        if path.name.startswith(_OPENAI_TEMP_PREFIX) and path.is_file():
            try:
                path.unlink()
            except OSError:
                pass


def call_jimeng(config: dict, prompt: str, model: str, negative_prompt: str, ratio: str, resolution: str) -> list[str]:
    """Call the JiMeng (即梦) image generation API and return image URLs.

    ``config["sessionid"]`` may be a list of session ids or a comma-separated
    string; the ids are sent comma-joined as the bearer token.

    Raises:
        RuntimeError: if ``base_url`` or ``sessionid`` is missing.
    """
    base_url = str(config.get("base_url") or "").rstrip("/")
    session_ids = config.get("sessionid") or []
    if isinstance(session_ids, str):
        # A plain string would otherwise be joined character by character.
        session_ids = [part.strip() for part in session_ids.split(",") if part.strip()]
    if not base_url or not session_ids:
        raise RuntimeError("即梦绘图配置缺少 base_url 或 sessionid")
    if not model or model == "none":
        model = "jimeng-5.0"
    if not ratio:
        ratio = "16:9"
    if not resolution:
        resolution = "2k"
    # If the leading number exceeds 4 (i.e. above 4k), fall back to 2k.
    # NOTE(review): this also rewrites pixel-style values such as "1080p".
    m = re.search(r"(\d+)", resolution)
    if m and int(m.group(1)) > 4:
        resolution = "2k"
    token = ",".join(str(session_id) for session_id in session_ids)
    body = {
        "model": model,
        "prompt": prompt,
        "ratio": ratio,
        "resolution": resolution,
        "response_format": "url",
        "sample_strength": 0.5,
    }
    if negative_prompt:
        body["negative_prompt"] = negative_prompt
    resp = _http_post_json(
        f"{base_url}/v1/images/generations",
        body,
        {"Content-Type": "application/json", "Authorization": f"Bearer {token}"},
        timeout=300,
    )
    return [item["url"] for item in resp.get("data", []) if item.get("url")]


def call_doubao(config: dict, prompt: str, model: str) -> list[str]:
    """Call the DouBao (豆包) Ark image generation API and return image URLs.

    Raises:
        RuntimeError: if ``api_key`` is missing.
    """
    api_key = config.get("api_key", "")
    if not api_key:
        raise RuntimeError("豆包绘图配置缺少 api_key")
    if not model or model == "none":
        model = "doubao-seedream-4.5"
    # Map friendly model names to actual endpoint model IDs.
    model_map = {
        "doubao-seedream-4.5": "doubao-seedream-4-5-251128",
        "doubao-seedream-4.0": "doubao-seedream-4-0-251128",
        "doubao-seedream-3.0-t2i": "doubao-seedream-3-0-t2i-250415",
        "doubao-seededit-3.0-i2i": "doubao-seededit-3-0-i2i-250628",
    }
    actual_model = model_map.get(model, model)
    body = {
        "model": actual_model,
        "prompt": prompt,
        "response_format": "url",
        "size": config.get("size", "2K"),
        "sequential_image_generation": config.get("sequential_image_generation", "auto"),
        "watermark": config.get("watermark", False),
    }
    image_val = config.get("image", "")
    if image_val:
        body["image"] = image_val
    resp = _http_post_json(
        "https://ark.cn-beijing.volces.com/api/v3/images/generations",
        body,
        {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
        timeout=300,
    )
    return [item["url"] for item in resp.get("data", []) if item.get("url")]


def call_zimage(config: dict, prompt: str, model: str) -> list[str]:
    """Call the Z-Image (造相) async API: create a task, then poll every 5 s for up to 15 minutes.

    Raises:
        RuntimeError: on missing config, unsupported model, missing task id,
            task failure, success without images, or timeout.
    """
    base_url = str(config.get("base_url") or "").rstrip("/")
    api_key = config.get("api_key", "")
    if not base_url or not api_key:
        raise RuntimeError("造相绘图配置缺少 base_url 或 api_key")
    if not model or model == "none":
        model = "Z-Image-Turbo"
    # Map friendly model names to upstream model IDs.
    model_map = {
        "Z-Image": "Tongyi-MAI/Z-Image",
        "Z-Image-Turbo": "Tongyi-MAI/Z-Image-Turbo",
        "Qwen-Image-Edit-2511": "Qwen/Qwen-Image-Edit-2511",
    }
    actual_model = model_map.get(model)
    if actual_model is None:
        raise RuntimeError(f"不支持的造相模型: {model}")
    body = {
        "model": actual_model,
        "prompt": prompt,
        "image_url": config.get("image_url", []),
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "X-ModelScope-Async-Mode": "true",
    }
    # Step 1: create the task.
    resp = _http_post_json(f"{base_url}/v1/images/generations", body, headers, timeout=30)
    task_id = resp.get("task_id", "")
    if not task_id:
        raise RuntimeError("造相接口未返回 task_id")
    # Step 2: poll for the result.
    poll_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "X-ModelScope-Task-Type": "image_generation",
    }
    deadline = time.time() + 15 * 60  # 15 minutes
    while time.time() < deadline:
        task_resp = _http_get_json(f"{base_url}/v1/tasks/{task_id}", poll_headers, timeout=30)
        status = task_resp.get("task_status", "")
        if status == "SUCCEED":
            images = task_resp.get("output_images", [])
            if images:
                return images
            raise RuntimeError("造相任务成功但未返回图片")
        if status == "FAILED":
            raise RuntimeError("造相绘图任务失败")
        time.sleep(5)
    raise RuntimeError("造相绘图任务超时")


def call_openai(config: dict, prompt: str, model: str, negative_prompt: str, ratio: str, resolution: str) -> list[str]:
    """Call the OpenAI GPT Image API and return local temp-file paths and/or URLs."""
    client = _openai_client(config)
    output_format = _openai_output_format(config)
    quality = str(config.get("quality", "auto") or "auto")
    moderation = str(config.get("moderation", "auto") or "auto")
    background = str(config.get("background", "auto") or "auto")
    # NOTE(review): "transparent" is deliberately downgraded to "auto";
    # presumably unsupported by the target model — confirm.
    if background == "transparent":
        background = "auto"
    kwargs = {
        "model": model or "gpt-image-2",
        "prompt": _openai_prompt(prompt, negative_prompt),
        "n": _coerce_int(config.get("n"), 1, 1, 10),
        "size": _openai_size(config, ratio, resolution),
        "quality": quality,
        "background": background,
        "moderation": moderation,
        "output_format": output_format,
    }
    # Compression only applies to lossy formats.
    if output_format in {"jpeg", "webp"} and config.get("output_compression") is not None:
        kwargs["output_compression"] = _coerce_int(config.get("output_compression"), 100, 0, 100)
    response = client.images.generate(**kwargs)
    _debug_response("openai images.generate response", response)
    return _openai_images_from_response(response, output_format)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

JIMENG_MODELS = {"jimeng-4.5", "jimeng-4.6", "jimeng-4.7", "jimeng-5.0"}
DOUBAO_MODELS = {"doubao-seedream-4.5", "doubao-seedream-4.0", "doubao-seedream-3.0-t2i", "doubao-seededit-3.0-i2i"}
ZIMAGE_MODELS = {"Z-Image", "Z-Image-Turbo", "Qwen-Image-Edit-2511"}
OPENAI_MODELS = {"gpt-image-2"}


def _parse_cli_params(argv: list[str]) -> dict[str, str]:
    """Parse the supported ``--option value`` flags into a dict of strings.

    Raises:
        ValueError: if *argv* contains unsupported arguments.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--prompt", default="")
    parser.add_argument("--model", default="")
    parser.add_argument("--negative_prompt", default="")
    parser.add_argument("--ratio", default="")
    parser.add_argument("--resolution", default="")
    namespace, unknown = parser.parse_known_args(argv)
    if unknown:
        raise ValueError(f"存在不支持的参数: {' '.join(unknown)}")
    return {
        "prompt": namespace.prompt,
        "model": namespace.model,
        "negative_prompt": namespace.negative_prompt,
        "ratio": namespace.ratio,
        "resolution": namespace.resolution,
    }


def main() -> int:
    """Run one drawing request end to end and return the process exit code.

    Returns 0 on success or when drawing is disabled for the conversation or
    provider; returns 1 on any validation, configuration, API or send failure.
    """
    if len(sys.argv) < 2:
        sys.stdout.write("缺少输入参数\n")
        return 1
    try:
        params = _parse_cli_params(sys.argv[1:])
    except ValueError as exc:
        sys.stdout.write(f"参数格式错误: {exc}\n")
        return 1

    prompt = params.get("prompt", "").strip()
    if not prompt:
        sys.stdout.write("缺少画图提示词\n")
        return 1
    model = params.get("model", "").strip()
    negative_prompt = params.get("negative_prompt", "").strip()
    ratio = params.get("ratio", "").strip()
    resolution = params.get("resolution", "").strip()

    from_wx_id = os.environ.get("ROBOT_FROM_WX_ID", "").strip()
    if not from_wx_id:
        sys.stdout.write("环境变量 ROBOT_FROM_WX_ID 未配置\n")
        return 1

    # Connect to the DB and load settings; the connection is closed exactly
    # once, in `finally`, whether loading succeeds or not.
    try:
        conn = _mysql_connect()
    except Exception as exc:
        sys.stdout.write(f"数据库连接失败: {exc}\n")
        return 1
    try:
        enabled, settings_json = load_drawing_settings(conn, from_wx_id)
    except Exception as exc:
        sys.stdout.write(f"加载绘图配置失败: {exc}\n")
        return 1
    finally:
        try:
            conn.close()
        except Exception:
            pass

    if not enabled:
        sys.stdout.write("AI 绘图未开启\n")
        return 0

    # Default model.
    if not model or model == "none":
        model = "jimeng-5.0"

    # Route to the provider that owns the model.
    try:
        image_urls: list[str] = []
        if model in JIMENG_MODELS:
            jimeng_config = settings_json.get("JiMeng", {})
            if not jimeng_config.get("enabled", False):
                sys.stdout.write("即梦绘图未开启\n")
                return 0
            image_urls = call_jimeng(jimeng_config, prompt, model, negative_prompt, ratio, resolution)
        elif model in DOUBAO_MODELS:
            doubao_config = settings_json.get("DouBao", {})
            if not doubao_config.get("enabled", False):
                sys.stdout.write("豆包绘图未开启\n")
                return 0
            image_urls = call_doubao(doubao_config, prompt, model)
        elif model in ZIMAGE_MODELS:
            zimage_config = settings_json.get("Z-Image", {})
            if not zimage_config.get("enabled", False):
                sys.stdout.write("造相绘图未开启\n")
                return 0
            image_urls = call_zimage(zimage_config, prompt, model)
        elif model in OPENAI_MODELS:
            openai_config = settings_json.get("OpenAI", {})
            if not openai_config.get("enabled", False):
                sys.stdout.write("OpenAI 绘图未开启\n")
                return 0
            image_urls = call_openai(openai_config, prompt, model, negative_prompt, ratio, resolution)
        else:
            sys.stdout.write("不支持的 AI 图像模型\n")
            return 1
    except Exception as exc:
        sys.stdout.write(f"调用绘图接口失败: {exc}\n")
        return 1

    if not image_urls:
        sys.stdout.write("未生成任何图像\n")
        return 1

    # Deliver the images through the local WeChat client API.
    client_port = os.environ.get("ROBOT_WECHAT_CLIENT_PORT", "").strip()
    if not client_port:
        _cleanup_openai_temp_files(image_urls)
        sys.stdout.write("环境变量 ROBOT_WECHAT_CLIENT_PORT 未配置\n")
        return 1
    try:
        _send_image_outputs(client_port, from_wx_id, image_urls)
        sys.stdout.write("图片发送成功\n")
    except Exception as exc:
        sys.stdout.write(f"发送图片失败: {exc}\n")
        return 1
    finally:
        _cleanup_openai_temp_files(image_urls)
    return 0


if __name__ == "__main__":
    try:
        raise SystemExit(main())
    except SystemExit:
        raise
    except Exception:
        traceback.print_exc(file=sys.stdout)
        raise SystemExit(1)