fix: 文生图、图生图支持 GPT

This commit is contained in:
hp0912 2026-05-05 23:08:42 +08:00
parent 20ea36a340
commit f812638e44
3 changed files with 83 additions and 76 deletions

13
.vscode/launch.json vendored
View File

@ -5,20 +5,9 @@
"name": "text-to-image",
"type": "debugpy",
"request": "launch",
"program": "skills/text-to-image/scripts/text_to_image.py",
"program": "skills/text-to-image/scripts/debug_openai_image_generation_test.py",
"console": "integratedTerminal",
"justMyCode": true,
"args": [
"{\"prompt\":\"一只站在雨夜街头的白猫\",\"model\":\"jimeng-5.0\",\"negative_prompt\":\"模糊, 低清\",\"ratio\":\"16:9\",\"resolution\":\"2k\"}"
],
"env": {
"ROBOT_FROM_WX_ID": "57004904192@chatroom",
"ROBOT_CODE": "houhouipad",
"MYSQL_HOST": "127.0.0.1",
"MYSQL_PORT": "3306",
"MYSQL_USER": "root",
"MYSQL_PASSWORD": "houhou"
}
}
]
}

View File

@ -236,50 +236,57 @@ def _openai_client(config: dict) -> OpenAI:
if not api_key:
raise RuntimeError("OpenAI 绘图配置缺少 api_key")
kwargs: dict[str, str | float] = {"api_key": api_key}
base_url = str(config.get("base_url", "") or "").strip()
if base_url:
kwargs["base_url"] = base_url
organization = str(config.get("organization", "") or "").strip()
if organization:
kwargs["organization"] = organization
project = str(config.get("project", "") or "").strip()
if project:
kwargs["project"] = project
timeout: float | None = None
timeout_value = config.get("timeout")
if timeout_value not in (None, ""):
kwargs["timeout"] = float(timeout_value)
timeout = float(timeout_value)
return OpenAI(**kwargs)
return OpenAI(
api_key=api_key,
base_url=base_url or None,
organization=organization or None,
project=project or None,
timeout=timeout,
)
def _openai_image_suffix(output_format: str) -> str:
if output_format == "jpeg":
return ".jpg"
return f".{output_format}"
def _truncate_debug_payload(value):
if isinstance(value, dict):
return {
key: (
f"{item[:50]}..." if key == "b64_json" and isinstance(item, str) and len(item) > 50 else _truncate_debug_payload(item)
)
for key, item in value.items()
}
if isinstance(value, list):
return [_truncate_debug_payload(item) for item in value]
return value
def _write_openai_image_file(b64_json: str, output_format: str) -> str:
image_bytes = base64.b64decode(b64_json)
with tempfile.NamedTemporaryFile(
prefix="wechat-openai-image-",
suffix=_openai_image_suffix(output_format),
delete=False,
) as image_file:
image_file.write(image_bytes)
return image_file.name
def _debug_response(label: str, payload) -> None:
if hasattr(payload, "model_dump"):
payload = payload.model_dump()
payload = _truncate_debug_payload(payload)
sys.stdout.write(f"[debug] {label}: {json.dumps(payload, ensure_ascii=False)}\n")
def _openai_images_from_response(response, output_format: str) -> list[str]:
def _rewrite_openai_image_url(url: str) -> str:
internal_host = "http://chatgpt2api:80"
external_host = "http://chatgpt2api.houhoukang.com"
if url.startswith(internal_host):
return f"{external_host}{url[len(internal_host):]}"
return url
def _openai_images_from_response(response) -> list[str]:
outputs: list[str] = []
for item in getattr(response, "data", []) or []:
b64_json = getattr(item, "b64_json", None)
if b64_json:
outputs.append(_write_openai_image_file(b64_json, output_format))
continue
url = getattr(item, "url", None)
if url:
outputs.append(url)
outputs.append(_rewrite_openai_image_url(str(url)))
return outputs
@ -297,7 +304,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
"to_wxid": from_wx_id,
"image_urls": remote_urls,
}
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
_debug_response("send image url response", response)
if local_paths:
send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/local"
@ -305,7 +313,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
"to_wxid": from_wx_id,
"file_path": local_paths,
}
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
_debug_response("send image local response", response)
def _cleanup_openai_temp_files(image_outputs: list[str]) -> None:
@ -531,7 +540,8 @@ def call_openai(config: dict, prompt: str, model: str, images: list[str],
for input_file in input_files:
input_file.close()
return _openai_images_from_response(response, output_format)
_debug_response("openai images.edit response", response)
return _openai_images_from_response(response)
# ---------------------------------------------------------------------------

View File

@ -3,13 +3,11 @@
from __future__ import annotations
import argparse
import base64
import json
import os
import re
import subprocess
import sys
import tempfile
import time
import traceback
import urllib.parse
@ -237,50 +235,57 @@ def _openai_client(config: dict) -> OpenAI:
if not api_key:
raise RuntimeError("OpenAI 绘图配置缺少 api_key")
kwargs: dict[str, str | float] = {"api_key": api_key}
base_url = str(config.get("base_url", "") or "").strip()
if base_url:
kwargs["base_url"] = base_url
organization = str(config.get("organization", "") or "").strip()
if organization:
kwargs["organization"] = organization
project = str(config.get("project", "") or "").strip()
if project:
kwargs["project"] = project
timeout: float | None = None
timeout_value = config.get("timeout")
if timeout_value not in (None, ""):
kwargs["timeout"] = float(timeout_value)
timeout = float(timeout_value)
return OpenAI(**kwargs)
return OpenAI(
api_key=api_key,
base_url=base_url or None,
organization=organization or None,
project=project or None,
timeout=timeout,
)
def _openai_image_suffix(output_format: str) -> str:
if output_format == "jpeg":
return ".jpg"
return f".{output_format}"
def _truncate_debug_payload(value):
if isinstance(value, dict):
return {
key: (
f"{item[:50]}..." if key == "b64_json" and isinstance(item, str) and len(item) > 50 else _truncate_debug_payload(item)
)
for key, item in value.items()
}
if isinstance(value, list):
return [_truncate_debug_payload(item) for item in value]
return value
def _write_openai_image_file(b64_json: str, output_format: str) -> str:
image_bytes = base64.b64decode(b64_json)
with tempfile.NamedTemporaryFile(
prefix="wechat-openai-image-",
suffix=_openai_image_suffix(output_format),
delete=False,
) as image_file:
image_file.write(image_bytes)
return image_file.name
def _debug_response(label: str, payload) -> None:
if hasattr(payload, "model_dump"):
payload = payload.model_dump()
payload = _truncate_debug_payload(payload)
sys.stdout.write(f"[debug] {label}: {json.dumps(payload, ensure_ascii=False)}\n")
def _openai_images_from_response(response, output_format: str) -> list[str]:
def _rewrite_openai_image_url(url: str) -> str:
internal_host = "http://chatgpt2api:80"
external_host = "http://chatgpt2api.houhoukang.com"
if url.startswith(internal_host):
return f"{external_host}{url[len(internal_host):]}"
return url
def _openai_images_from_response(response) -> list[str]:
outputs: list[str] = []
for item in getattr(response, "data", []) or []:
b64_json = getattr(item, "b64_json", None)
if b64_json:
outputs.append(_write_openai_image_file(b64_json, output_format))
continue
url = getattr(item, "url", None)
if url:
outputs.append(url)
outputs.append(_rewrite_openai_image_url(str(url)))
return outputs
@ -298,7 +303,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
"to_wxid": from_wx_id,
"image_urls": remote_urls,
}
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
_debug_response("send image url response", response)
if local_paths:
send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/local"
@ -306,7 +312,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
"to_wxid": from_wx_id,
"file_path": local_paths,
}
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
_debug_response("send image local response", response)
def _cleanup_openai_temp_files(image_outputs: list[str]) -> None:
@ -490,7 +497,8 @@ def call_openai(config: dict, prompt: str, model: str,
kwargs["output_compression"] = _coerce_int(config.get("output_compression"), 100, 0, 100)
response = client.images.generate(**kwargs)
return _openai_images_from_response(response, output_format)
_debug_response("openai images.generate response", response)
return _openai_images_from_response(response)
# ---------------------------------------------------------------------------