wechat-robot-skills/skills/text-to-image/scripts/text_to_image.py
2026-05-05 14:15:32 +08:00

651 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import base64
import json
import os
import re
import subprocess
import sys
import tempfile
import time
import traceback
import urllib.parse
import urllib.request
from pathlib import Path
# The skill runner consumes stdout only, so alias sys.stderr to sys.stdout:
# anything Python writes to stderr (uncaught tracebacks, warnings) then reaches
# the runner as well.
sys.stderr = sys.stdout
def _skill_root() -> Path:
script_dir = Path(__file__).resolve().parent
return script_dir.parent
def _skill_venv_python() -> Path:
    """Return the expected interpreter path inside the skill's ``.venv``.

    Windows venvs keep the interpreter in ``Scripts/python.exe``; everything
    else uses ``bin/python``.
    """
    if sys.platform == "win32":
        bin_dir, exe_name = "Scripts", "python.exe"
    else:
        bin_dir, exe_name = "bin", "python"
    return _skill_root() / ".venv" / bin_dir / exe_name
def _get_python_executable() -> str:
if sys.executable:
return sys.executable
import shutil
for candidate in ("python3", "python"):
found = shutil.which(candidate)
if found:
return found
raise RuntimeError("无法找到 Python 解释器路径")
def _run_bootstrap() -> None:
    """Run ``bootstrap.py`` (next to this script) with the current interpreter.

    Raises:
        SystemExit: carrying the bootstrap's exit status if it is non-zero.
    """
    bootstrap_script = Path(__file__).resolve().parent / "bootstrap.py"
    completed = subprocess.run([_get_python_executable(), str(bootstrap_script)])
    if completed.returncode:
        raise SystemExit(completed.returncode)
def _ensure_skill_venv_python() -> None:
    """Make sure this script runs under the skill's own virtualenv interpreter.

    If the venv interpreter is missing, runs the bootstrap script once. If it
    is still missing afterwards, writes an error to stdout and exits with
    status 1. When the current interpreter is not the venv one, replaces this
    process (``os.execv``) with the venv interpreter running the same script
    and arguments. Returns normally only when already inside the venv.
    """
    venv_python = _skill_venv_python()
    if not venv_python.is_file():
        _run_bootstrap()
        venv_python = _skill_venv_python()
        if not venv_python.is_file():
            sys.stdout.write("bootstrap 后仍未找到虚拟环境\n")
            raise SystemExit(1)
    venv_dir = _skill_root() / ".venv"
    # Resolve BOTH sides: inside a venv, sys.prefix is derived from the
    # interpreter's (non-symlink-resolved) location. If `.venv` is a symlink,
    # comparing that raw prefix against a resolved path never matches and the
    # execv below would re-launch this script forever.
    if Path(sys.prefix).resolve() == venv_dir.resolve():
        return
    os.execv(str(venv_python), [str(venv_python), str(Path(__file__).resolve()), *sys.argv[1:]])
# Import-time setup: switch to the skill venv first, then import third-party
# dependencies. If they are missing, run bootstrap and re-exec this script once.
_ensure_skill_venv_python()
try:
    import pymysql  # type: ignore  # noqa: E402
    from openai import OpenAI  # type: ignore  # noqa: E402
except ModuleNotFoundError as _missing_exc:
    # Without this guard, a bootstrap that "succeeds" but does not make the
    # modules importable would bootstrap + re-exec endlessly. The marker lives
    # only in this process's environment (inherited across execv).
    if os.environ.get("TEXT_TO_IMAGE_BOOTSTRAP_RETRIED") == "1":
        sys.stdout.write(f"bootstrap 后仍缺少依赖模块: {_missing_exc.name}\n")
        raise SystemExit(1)
    _run_bootstrap()
    os.environ["TEXT_TO_IMAGE_BOOTSTRAP_RETRIED"] = "1"
    _py = _get_python_executable()
    os.execv(_py, [_py, str(Path(__file__).resolve()), *sys.argv[1:]])
# ---------------------------------------------------------------------------
# Database helpers
# ---------------------------------------------------------------------------
def _mysql_connect():
host = os.environ.get("MYSQL_HOST", "127.0.0.1")
port = int(os.environ.get("MYSQL_PORT", "3306"))
user = os.environ.get("MYSQL_USER", "root")
password = os.environ.get("MYSQL_PASSWORD", "")
database = os.environ.get("ROBOT_CODE", "")
if not database:
raise RuntimeError("环境变量 ROBOT_CODE 未配置")
return pymysql.connect(
host=host, port=port, user=user, password=password,
database=database, charset="utf8mb4",
connect_timeout=10, read_timeout=30,
)
def _query_one(conn, sql: str, params: tuple = ()) -> dict | None:
cur = conn.cursor()
cur.execute(sql, params)
columns = [desc[0] for desc in cur.description] if cur.description else []
row = cur.fetchone()
cur.close()
if row is None:
return None
return dict(zip(columns, row))
# ---------------------------------------------------------------------------
# Settings resolution (mirrors the Go service logic)
# ---------------------------------------------------------------------------
def load_drawing_settings(conn, from_wx_id: str) -> tuple[bool, dict]:
"""Return (enabled, image_ai_settings_dict)."""
# 1. global_settings
gs = _query_one(conn, "SELECT image_ai_enabled, image_ai_settings FROM global_settings LIMIT 1")
enabled = False
settings_json: dict = {}
if gs:
if gs.get("image_ai_enabled"):
enabled = bool(gs["image_ai_enabled"])
raw = gs.get("image_ai_settings")
if raw:
if isinstance(raw, (bytes, bytearray)):
raw = raw.decode("utf-8")
if isinstance(raw, str) and raw.strip():
settings_json = json.loads(raw)
# 2. override from chatroom / friend settings
if from_wx_id.endswith("@chatroom"):
override = _query_one(
conn,
"SELECT image_ai_enabled, image_ai_settings FROM chat_room_settings WHERE chat_room_id = %s LIMIT 1",
(from_wx_id,),
)
else:
override = _query_one(
conn,
"SELECT image_ai_enabled, image_ai_settings FROM friend_settings WHERE wechat_id = %s LIMIT 1",
(from_wx_id,),
)
if override:
if override.get("image_ai_enabled") is not None:
enabled = bool(override["image_ai_enabled"])
raw = override.get("image_ai_settings")
if raw:
if isinstance(raw, (bytes, bytearray)):
raw = raw.decode("utf-8")
if isinstance(raw, str) and raw.strip():
settings_json = json.loads(raw)
return enabled, settings_json
# ---------------------------------------------------------------------------
# API callers
# ---------------------------------------------------------------------------
def _http_post_json(url: str, body: dict, headers: dict, timeout: int = 300) -> dict:
data = json.dumps(body).encode("utf-8")
req = urllib.request.Request(url, data=data, headers=headers, method="POST")
with urllib.request.urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read().decode("utf-8"))
def _http_get_json(url: str, headers: dict, timeout: int = 30) -> dict:
req = urllib.request.Request(url, headers=headers, method="GET")
with urllib.request.urlopen(req, timeout=timeout) as resp:
return json.loads(resp.read().decode("utf-8"))
def _coerce_int(value, default: int, minimum: int, maximum: int) -> int:
try:
parsed = int(value)
except (TypeError, ValueError):
parsed = default
return min(max(parsed, minimum), maximum)
def _openai_output_format(config: dict) -> str:
output_format = str(config.get("output_format", "png") or "png").lower()
if output_format not in {"png", "jpeg", "webp"}:
return "png"
return output_format
def _openai_size(config: dict, ratio: str, resolution: str) -> str:
configured = str(config.get("size", "") or "").strip()
if configured:
return configured
normalized_ratio = (ratio or "").replace(" ", "").lower()
normalized_resolution = (resolution or "").replace(" ", "").lower()
if normalized_resolution in {"4k", "2160p", "3840x2160"}:
sizes = {
"16:9": "3840x2160",
"9:16": "2160x3840",
"1:1": "2048x2048",
"3:2": "3072x2048",
"2:3": "2048x3072",
}
elif normalized_resolution in {"2k", "1440p", "2048"}:
sizes = {
"16:9": "2048x1152",
"9:16": "1152x2048",
"1:1": "2048x2048",
"3:2": "2048x1360",
"2:3": "1360x2048",
}
elif normalized_resolution in {"1k", "1024", "1024p"}:
sizes = {
"16:9": "1536x864",
"9:16": "864x1536",
"1:1": "1024x1024",
"3:2": "1536x1024",
"2:3": "1024x1536",
}
else:
return "auto"
return sizes.get(normalized_ratio, "auto")
def _openai_prompt(prompt: str, negative_prompt: str) -> str:
if not negative_prompt:
return prompt
return f"{prompt}\n\n不要包含: {negative_prompt}"
def _openai_client(config: dict) -> OpenAI:
api_key = str(config.get("api_key", "")).strip()
if not api_key:
raise RuntimeError("OpenAI 绘图配置缺少 api_key")
kwargs: dict[str, str | float] = {"api_key": api_key}
base_url = str(config.get("base_url", "") or "").strip()
if base_url:
kwargs["base_url"] = base_url
organization = str(config.get("organization", "") or "").strip()
if organization:
kwargs["organization"] = organization
project = str(config.get("project", "") or "").strip()
if project:
kwargs["project"] = project
timeout_value = config.get("timeout")
if timeout_value not in (None, ""):
kwargs["timeout"] = float(timeout_value)
return OpenAI(**kwargs)
def _openai_image_suffix(output_format: str) -> str:
if output_format == "jpeg":
return ".jpg"
return f".{output_format}"
def _write_openai_image_file(b64_json: str, output_format: str) -> str:
    """Decode a base64 image into a new temp file and return the file's path.

    The file is created with ``delete=False`` and the ``wechat-openai-image-``
    prefix. The caller owns it and must remove it after use.
    """
    payload = base64.b64decode(b64_json)
    suffix = _openai_image_suffix(output_format)
    with tempfile.NamedTemporaryFile(prefix="wechat-openai-image-", suffix=suffix, delete=False) as handle:
        handle.write(payload)
        written_path = handle.name
    return written_path
def _openai_images_from_response(response, output_format: str) -> list[str]:
outputs: list[str] = []
for item in getattr(response, "data", []) or []:
b64_json = getattr(item, "b64_json", None)
if b64_json:
outputs.append(_write_openai_image_file(b64_json, output_format))
continue
url = getattr(item, "url", None)
if url:
outputs.append(url)
return outputs
def _is_remote_image_url(value: str) -> bool:
return urllib.parse.urlparse(value).scheme in {"http", "https"}
def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[str]) -> None:
remote_urls = [value for value in image_outputs if value and _is_remote_image_url(value)]
local_paths = [value for value in image_outputs if value and not _is_remote_image_url(value)]
if remote_urls:
send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/url"
send_body = {
"to_wxid": from_wx_id,
"image_urls": remote_urls,
}
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
if local_paths:
send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/local"
send_body = {
"to_wxid": from_wx_id,
"file_path": local_paths,
}
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
def _cleanup_openai_temp_files(image_outputs: list[str]) -> None:
for value in image_outputs:
path = Path(value)
if path.name.startswith("wechat-openai-image-") and path.is_file():
try:
path.unlink()
except OSError:
pass
def call_jimeng(config: dict, prompt: str, model: str,
negative_prompt: str, ratio: str, resolution: str) -> list[str]:
"""Call JiMeng (即梦) image generation API."""
base_url = config.get("base_url", "").rstrip("/")
session_ids = config.get("sessionid", [])
if not base_url or not session_ids:
raise RuntimeError("即梦绘图配置缺少 base_url 或 sessionid")
if not model or model == "none":
model = "jimeng-5.0"
if not ratio:
ratio = "16:9"
if not resolution:
resolution = "2k"
# 如果分辨率大于4k重置为2k
m = re.search(r"(\d+)", resolution)
if m and int(m.group(1)) > 4:
resolution = "2k"
token = ",".join(session_ids)
body = {
"model": model,
"prompt": prompt,
"ratio": ratio,
"resolution": resolution,
"response_format": "url",
"sample_strength": 0.5,
}
if negative_prompt:
body["negative_prompt"] = negative_prompt
resp = _http_post_json(
f"{base_url}/v1/images/generations",
body,
{"Content-Type": "application/json", "Authorization": f"Bearer {token}"},
timeout=300,
)
urls = [item["url"] for item in resp.get("data", []) if item.get("url")]
return urls
def call_doubao(config: dict, prompt: str, model: str) -> list[str]:
    """Call the DouBao (豆包) Seedream image API on Volcengine Ark; return image URLs.

    Friendly model names are translated to endpoint model ids; unknown names
    are sent unchanged. Defaults to ``doubao-seedream-4.5``. ``size``,
    ``sequential_image_generation``, ``watermark`` and an optional ``image``
    are taken from *config*.

    Raises:
        RuntimeError: if ``api_key`` is missing.
    """
    api_key = config.get("api_key", "")
    if not api_key:
        raise RuntimeError("豆包绘图配置缺少 api_key")
    friendly = model if model and model != "none" else "doubao-seedream-4.5"
    endpoint_ids = {
        "doubao-seedream-4.5": "doubao-seedream-4-5-251128",
        "doubao-seedream-4.0": "doubao-seedream-4-0-251128",
        "doubao-seedream-3.0-t2i": "doubao-seedream-3-0-t2i-250415",
        "doubao-seededit-3.0-i2i": "doubao-seededit-3-0-i2i-250628",
    }
    payload = {
        "model": endpoint_ids.get(friendly, friendly),
        "prompt": prompt,
        "response_format": "url",
        "size": config.get("size", "2K"),
        "sequential_image_generation": config.get("sequential_image_generation", "auto"),
        "watermark": config.get("watermark", False),
    }
    reference_image = config.get("image", "")
    if reference_image:
        payload["image"] = reference_image
    response = _http_post_json(
        "https://ark.cn-beijing.volces.com/api/v3/images/generations",
        payload,
        {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
        timeout=300,
    )
    return [entry.get("url") for entry in response.get("data", []) if entry.get("url")]
def call_zimage(config: dict, prompt: str, model: str) -> list[str]:
    """Call the Z-Image (造相) image API (ModelScope async task mode).

    Submits a generation task, then polls ``/v1/tasks/{task_id}`` every 5
    seconds for up to 15 minutes. Defaults to ``Z-Image-Turbo``.

    Returns:
        The ``output_images`` list of the succeeded task.

    Raises:
        RuntimeError: on missing ``base_url``/``api_key``, an unsupported model,
            a missing ``task_id``, task failure, an empty result, or timeout.
    """
    # ``or ""`` guards against explicit JSON nulls in the config.
    base_url = str(config.get("base_url") or "").rstrip("/")
    api_key = str(config.get("api_key") or "").strip()
    if not base_url or not api_key:
        raise RuntimeError("造相绘图配置缺少 base_url 或 api_key")
    if not model or model == "none":
        model = "Z-Image-Turbo"
    model_map = {
        "Z-Image": "Tongyi-MAI/Z-Image",
        "Z-Image-Turbo": "Tongyi-MAI/Z-Image-Turbo",
        "Qwen-Image-Edit-2511": "Qwen/Qwen-Image-Edit-2511",
    }
    actual_model = model_map.get(model)
    if actual_model is None:
        raise RuntimeError(f"不支持的造相模型: {model}")
    body = {
        "model": actual_model,
        "prompt": prompt,
        "image_url": config.get("image_url") or [],
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "X-ModelScope-Async-Mode": "true",
    }
    # Step 1: create the task.
    resp = _http_post_json(f"{base_url}/v1/images/generations", body, headers, timeout=30)
    task_id = resp.get("task_id", "")
    if not task_id:
        raise RuntimeError("造相接口未返回 task_id")
    # Step 2: poll for the result. A monotonic clock keeps the 15-minute budget
    # immune to wall-clock adjustments.
    poll_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "X-ModelScope-Task-Type": "image_generation",
    }
    deadline = time.monotonic() + 15 * 60
    while time.monotonic() < deadline:
        task_resp = _http_get_json(f"{base_url}/v1/tasks/{task_id}", poll_headers, timeout=30)
        status = task_resp.get("task_status", "")
        if status == "SUCCEED":
            images = task_resp.get("output_images", [])
            if images:
                return images
            raise RuntimeError("造相任务成功但未返回图片")
        if status == "FAILED":
            raise RuntimeError("造相绘图任务失败")
        time.sleep(5)
    raise RuntimeError("造相绘图任务超时")
def call_openai(config: dict, prompt: str, model: str,
                negative_prompt: str, ratio: str, resolution: str) -> list[str]:
    """Generate images with the OpenAI GPT Image API.

    Defaults to ``gpt-image-2``. ``quality``, ``moderation`` and ``background``
    default to ``auto``, and a ``transparent`` background is downgraded to
    ``auto``. ``n`` is clamped to 1..10, and ``output_compression`` (0..100) is
    only sent for jpeg/webp output.

    Returns:
        Temp-file paths (for base64 results) and/or remote URLs.
    """
    client = _openai_client(config)
    output_format = _openai_output_format(config)

    def auto_option(key: str) -> str:
        return str(config.get(key, "auto") or "auto")

    background = auto_option("background")
    request = {
        "model": model or "gpt-image-2",
        "prompt": _openai_prompt(prompt, negative_prompt),
        "n": _coerce_int(config.get("n"), 1, 1, 10),
        "size": _openai_size(config, ratio, resolution),
        "quality": auto_option("quality"),
        "background": "auto" if background == "transparent" else background,
        "moderation": auto_option("moderation"),
        "output_format": output_format,
    }
    wants_compression = output_format in {"jpeg", "webp"}
    if wants_compression and config.get("output_compression") is not None:
        request["output_compression"] = _coerce_int(config.get("output_compression"), 100, 0, 100)
    return _openai_images_from_response(client.images.generate(**request), output_format)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
# Model names accepted by each provider; main() routes a request by membership.
JIMENG_MODELS = {
    "jimeng-4.5",
    "jimeng-4.6",
    "jimeng-5.0",
}
DOUBAO_MODELS = {
    "doubao-seedream-4.5",
    "doubao-seedream-4.0",
    "doubao-seedream-3.0-t2i",
    "doubao-seededit-3.0-i2i",
}
ZIMAGE_MODELS = {
    "Z-Image",
    "Z-Image-Turbo",
    "Qwen-Image-Edit-2511",
}
OPENAI_MODELS = {
    "gpt-image-2",
}
def _parse_cli_params(argv: list[str]) -> dict[str, str]:
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--prompt", default="")
parser.add_argument("--model", default="")
parser.add_argument("--negative_prompt", default="")
parser.add_argument("--ratio", default="")
parser.add_argument("--resolution", default="")
namespace, unknown = parser.parse_known_args(argv)
if unknown:
raise ValueError(f"存在不支持的参数: {' '.join(unknown)}")
return {
"prompt": namespace.prompt,
"model": namespace.model,
"negative_prompt": namespace.negative_prompt,
"ratio": namespace.ratio,
"resolution": namespace.resolution,
}
def main() -> int:
    """Run the text-to-image skill end to end and return a process exit code.

    Steps:
    1. Parse the CLI options.
    2. Load the drawing settings for ``ROBOT_FROM_WX_ID`` from MySQL.
    3. Call the provider serving the requested model (default ``jimeng-5.0``).
    4. Send the results through the local WeChat client API on
       ``ROBOT_WECHAT_CLIENT_PORT``.

    All status messages are written to stdout.

    Returns:
        0 on success or when drawing / the selected provider is disabled;
        1 on any error.
    """
    if len(sys.argv) < 2:
        sys.stdout.write("缺少输入参数\n")
        return 1
    try:
        params = _parse_cli_params(sys.argv[1:])
    except ValueError as exc:
        sys.stdout.write(f"参数格式错误: {exc}\n")
        return 1
    prompt = params.get("prompt", "").strip()
    if not prompt:
        sys.stdout.write("缺少画图提示词\n")
        return 1
    model = params.get("model", "").strip()
    negative_prompt = params.get("negative_prompt", "").strip()
    ratio = params.get("ratio", "").strip()
    resolution = params.get("resolution", "").strip()
    from_wx_id = os.environ.get("ROBOT_FROM_WX_ID", "").strip()
    if not from_wx_id:
        sys.stdout.write("环境变量 ROBOT_FROM_WX_ID 未配置\n")
        return 1

    # Connect to the DB only long enough to load the settings.
    try:
        conn = _mysql_connect()
    except Exception as exc:
        sys.stdout.write(f"数据库连接失败: {exc}\n")
        return 1
    try:
        enabled, settings_json = load_drawing_settings(conn, from_wx_id)
    except Exception as exc:
        sys.stdout.write(f"加载绘图配置失败: {exc}\n")
        return 1
    finally:
        # Single close point for both paths; a failing close must not mask
        # the real outcome.
        try:
            conn.close()
        except Exception:
            pass
    if not enabled:
        sys.stdout.write("AI 绘图未开启\n")
        return 0

    # Default model
    if not model or model == "none":
        model = "jimeng-5.0"

    # Route to the provider serving this model. A provider section stored as
    # JSON null is treated like a missing (disabled) section.
    try:
        image_urls: list[str] = []
        if model in JIMENG_MODELS:
            jimeng_config = settings_json.get("JiMeng") or {}
            if not jimeng_config.get("enabled", False):
                sys.stdout.write("即梦绘图未开启\n")
                return 0
            image_urls = call_jimeng(jimeng_config, prompt, model, negative_prompt, ratio, resolution)
        elif model in DOUBAO_MODELS:
            doubao_config = settings_json.get("DouBao") or {}
            if not doubao_config.get("enabled", False):
                sys.stdout.write("豆包绘图未开启\n")
                return 0
            image_urls = call_doubao(doubao_config, prompt, model)
        elif model in ZIMAGE_MODELS:
            zimage_config = settings_json.get("Z-Image") or {}
            if not zimage_config.get("enabled", False):
                sys.stdout.write("造相绘图未开启\n")
                return 0
            image_urls = call_zimage(zimage_config, prompt, model)
        elif model in OPENAI_MODELS:
            openai_config = settings_json.get("OpenAI") or {}
            if not openai_config.get("enabled", False):
                sys.stdout.write("OpenAI 绘图未开启\n")
                return 0
            image_urls = call_openai(openai_config, prompt, model, negative_prompt, ratio, resolution)
        else:
            sys.stdout.write("不支持的 AI 图像模型\n")
            return 1
    except Exception as exc:
        sys.stdout.write(f"调用绘图接口失败: {exc}\n")
        return 1
    if not image_urls:
        sys.stdout.write("未生成任何图像\n")
        return 1

    # Send the images through the local WeChat client API; temp files written
    # for OpenAI base64 results are removed on every path.
    client_port = os.environ.get("ROBOT_WECHAT_CLIENT_PORT", "").strip()
    if not client_port:
        _cleanup_openai_temp_files(image_urls)
        sys.stdout.write("环境变量 ROBOT_WECHAT_CLIENT_PORT 未配置\n")
        return 1
    try:
        _send_image_outputs(client_port, from_wx_id, image_urls)
        sys.stdout.write("图片发送成功\n")
    except Exception as exc:
        sys.stdout.write(f"发送图片失败: {exc}\n")
        return 1
    finally:
        _cleanup_openai_temp_files(image_urls)
    return 0
if __name__ == "__main__":
    # Exit with main()'s status. An unexpected exception is printed to stdout
    # (the only stream the skill runner reads) and mapped to exit status 1.
    # SystemExit / KeyboardInterrupt are not Exception subclasses, so they
    # propagate unchanged.
    try:
        exit_code = main()
    except Exception:
        traceback.print_exc(file=sys.stdout)
        exit_code = 1
    raise SystemExit(exit_code)