fix: 文生图、图生图支持 GPT
This commit is contained in:
parent
20ea36a340
commit
f812638e44
13
.vscode/launch.json
vendored
13
.vscode/launch.json
vendored
@ -5,20 +5,9 @@
|
||||
"name": "text-to-image",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "skills/text-to-image/scripts/text_to_image.py",
|
||||
"program": "skills/text-to-image/scripts/debug_openai_image_generation_test.py",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"args": [
|
||||
"{\"prompt\":\"一只站在雨夜街头的白猫\",\"model\":\"jimeng-5.0\",\"negative_prompt\":\"模糊, 低清\",\"ratio\":\"16:9\",\"resolution\":\"2k\"}"
|
||||
],
|
||||
"env": {
|
||||
"ROBOT_FROM_WX_ID": "57004904192@chatroom",
|
||||
"ROBOT_CODE": "houhouipad",
|
||||
"MYSQL_HOST": "127.0.0.1",
|
||||
"MYSQL_PORT": "3306",
|
||||
"MYSQL_USER": "root",
|
||||
"MYSQL_PASSWORD": "houhou"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -236,50 +236,57 @@ def _openai_client(config: dict) -> OpenAI:
|
||||
if not api_key:
|
||||
raise RuntimeError("OpenAI 绘图配置缺少 api_key")
|
||||
|
||||
kwargs: dict[str, str | float] = {"api_key": api_key}
|
||||
base_url = str(config.get("base_url", "") or "").strip()
|
||||
if base_url:
|
||||
kwargs["base_url"] = base_url
|
||||
organization = str(config.get("organization", "") or "").strip()
|
||||
if organization:
|
||||
kwargs["organization"] = organization
|
||||
project = str(config.get("project", "") or "").strip()
|
||||
if project:
|
||||
kwargs["project"] = project
|
||||
timeout: float | None = None
|
||||
timeout_value = config.get("timeout")
|
||||
if timeout_value not in (None, ""):
|
||||
kwargs["timeout"] = float(timeout_value)
|
||||
timeout = float(timeout_value)
|
||||
|
||||
return OpenAI(**kwargs)
|
||||
return OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or None,
|
||||
organization=organization or None,
|
||||
project=project or None,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
def _openai_image_suffix(output_format: str) -> str:
|
||||
if output_format == "jpeg":
|
||||
return ".jpg"
|
||||
return f".{output_format}"
|
||||
def _truncate_debug_payload(value):
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
key: (
|
||||
f"{item[:50]}..." if key == "b64_json" and isinstance(item, str) and len(item) > 50 else _truncate_debug_payload(item)
|
||||
)
|
||||
for key, item in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [_truncate_debug_payload(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _write_openai_image_file(b64_json: str, output_format: str) -> str:
    """Decode a base64 image payload into a persistent temporary file.

    Args:
        b64_json: Base64-encoded image bytes.
        output_format: Image format name; ``_openai_image_suffix`` turns it
            into the file suffix.

    Returns:
        Path of the written file. It is created with ``delete=False``, so it
        outlives this call and must be removed by the caller.

    Raises:
        binascii.Error: If ``b64_json`` is not valid base64.
        OSError: If the temporary file cannot be created or written.
    """
    import os  # local import: the file's import block is only partly visible

    # Decode before creating the file so bad input never leaves a file behind.
    image_bytes = base64.b64decode(b64_json)
    image_file = tempfile.NamedTemporaryFile(
        prefix="wechat-openai-image-",
        suffix=_openai_image_suffix(output_format),
        delete=False,
    )
    try:
        with image_file:
            image_file.write(image_bytes)
    except BaseException:
        # delete=False means nobody else removes a half-written file.
        try:
            os.unlink(image_file.name)
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise
    return image_file.name
|
||||
def _debug_response(label: str, payload) -> None:
    """Write a single ``[debug] <label>: <payload>`` line to stdout.

    Objects exposing ``model_dump()`` are converted through it first, then
    long ``b64_json`` strings are shortened by ``_truncate_debug_payload``.

    Args:
        label: Short description of where the payload came from.
        payload: Response object or plain data to dump.
    """
    if hasattr(payload, "model_dump"):
        payload = payload.model_dump()
    payload = _truncate_debug_payload(payload)
    try:
        rendered = json.dumps(payload, ensure_ascii=False)
    except (TypeError, ValueError):
        # Debug output must never abort the real work: fall back to repr()
        # when the payload holds values json cannot encode.
        rendered = repr(payload)
    sys.stdout.write(f"[debug] {label}: {rendered}\n")
|
||||
|
||||
|
||||
def _openai_images_from_response(response, output_format: str) -> list[str]:
|
||||
def _rewrite_openai_image_url(url: str) -> str:
|
||||
internal_host = "http://chatgpt2api:80"
|
||||
external_host = "http://chatgpt2api.houhoukang.com"
|
||||
if url.startswith(internal_host):
|
||||
return f"{external_host}{url[len(internal_host):]}"
|
||||
return url
|
||||
|
||||
|
||||
def _openai_images_from_response(response) -> list[str]:
|
||||
outputs: list[str] = []
|
||||
for item in getattr(response, "data", []) or []:
|
||||
b64_json = getattr(item, "b64_json", None)
|
||||
if b64_json:
|
||||
outputs.append(_write_openai_image_file(b64_json, output_format))
|
||||
continue
|
||||
url = getattr(item, "url", None)
|
||||
if url:
|
||||
outputs.append(url)
|
||||
outputs.append(_rewrite_openai_image_url(str(url)))
|
||||
return outputs
|
||||
|
||||
|
||||
@ -297,7 +304,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
|
||||
"to_wxid": from_wx_id,
|
||||
"image_urls": remote_urls,
|
||||
}
|
||||
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||
_debug_response("send image url response", response)
|
||||
|
||||
if local_paths:
|
||||
send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/local"
|
||||
@ -305,7 +313,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
|
||||
"to_wxid": from_wx_id,
|
||||
"file_path": local_paths,
|
||||
}
|
||||
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||
_debug_response("send image local response", response)
|
||||
|
||||
|
||||
def _cleanup_openai_temp_files(image_outputs: list[str]) -> None:
|
||||
@ -531,7 +540,8 @@ def call_openai(config: dict, prompt: str, model: str, images: list[str],
|
||||
for input_file in input_files:
|
||||
input_file.close()
|
||||
|
||||
return _openai_images_from_response(response, output_format)
|
||||
_debug_response("openai images.edit response", response)
|
||||
return _openai_images_from_response(response)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -3,13 +3,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
import urllib.parse
|
||||
@ -237,50 +235,57 @@ def _openai_client(config: dict) -> OpenAI:
|
||||
if not api_key:
|
||||
raise RuntimeError("OpenAI 绘图配置缺少 api_key")
|
||||
|
||||
kwargs: dict[str, str | float] = {"api_key": api_key}
|
||||
base_url = str(config.get("base_url", "") or "").strip()
|
||||
if base_url:
|
||||
kwargs["base_url"] = base_url
|
||||
organization = str(config.get("organization", "") or "").strip()
|
||||
if organization:
|
||||
kwargs["organization"] = organization
|
||||
project = str(config.get("project", "") or "").strip()
|
||||
if project:
|
||||
kwargs["project"] = project
|
||||
timeout: float | None = None
|
||||
timeout_value = config.get("timeout")
|
||||
if timeout_value not in (None, ""):
|
||||
kwargs["timeout"] = float(timeout_value)
|
||||
timeout = float(timeout_value)
|
||||
|
||||
return OpenAI(**kwargs)
|
||||
return OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url or None,
|
||||
organization=organization or None,
|
||||
project=project or None,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
def _openai_image_suffix(output_format: str) -> str:
|
||||
if output_format == "jpeg":
|
||||
return ".jpg"
|
||||
return f".{output_format}"
|
||||
def _truncate_debug_payload(value):
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
key: (
|
||||
f"{item[:50]}..." if key == "b64_json" and isinstance(item, str) and len(item) > 50 else _truncate_debug_payload(item)
|
||||
)
|
||||
for key, item in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [_truncate_debug_payload(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _write_openai_image_file(b64_json: str, output_format: str) -> str:
    """Decode a base64 image payload into a persistent temporary file.

    Args:
        b64_json: Base64-encoded image bytes.
        output_format: Image format name; ``_openai_image_suffix`` turns it
            into the file suffix.

    Returns:
        Path of the written file. It is created with ``delete=False``, so it
        outlives this call and must be removed by the caller.

    Raises:
        binascii.Error: If ``b64_json`` is not valid base64.
        OSError: If the temporary file cannot be created or written.
    """
    import os  # local import keeps this block self-contained

    # Decode before creating the file so bad input never leaves a file behind.
    image_bytes = base64.b64decode(b64_json)
    image_file = tempfile.NamedTemporaryFile(
        prefix="wechat-openai-image-",
        suffix=_openai_image_suffix(output_format),
        delete=False,
    )
    try:
        with image_file:
            image_file.write(image_bytes)
    except BaseException:
        # delete=False means nobody else removes a half-written file.
        try:
            os.unlink(image_file.name)
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise
    return image_file.name
|
||||
def _debug_response(label: str, payload) -> None:
    """Write a single ``[debug] <label>: <payload>`` line to stdout.

    Objects exposing ``model_dump()`` are converted through it first, then
    long ``b64_json`` strings are shortened by ``_truncate_debug_payload``.

    Args:
        label: Short description of where the payload came from.
        payload: Response object or plain data to dump.
    """
    if hasattr(payload, "model_dump"):
        payload = payload.model_dump()
    payload = _truncate_debug_payload(payload)
    try:
        rendered = json.dumps(payload, ensure_ascii=False)
    except (TypeError, ValueError):
        # Debug output must never abort the real work: fall back to repr()
        # when the payload holds values json cannot encode.
        rendered = repr(payload)
    sys.stdout.write(f"[debug] {label}: {rendered}\n")
|
||||
|
||||
|
||||
def _openai_images_from_response(response, output_format: str) -> list[str]:
|
||||
def _rewrite_openai_image_url(url: str) -> str:
|
||||
internal_host = "http://chatgpt2api:80"
|
||||
external_host = "http://chatgpt2api.houhoukang.com"
|
||||
if url.startswith(internal_host):
|
||||
return f"{external_host}{url[len(internal_host):]}"
|
||||
return url
|
||||
|
||||
|
||||
def _openai_images_from_response(response) -> list[str]:
|
||||
outputs: list[str] = []
|
||||
for item in getattr(response, "data", []) or []:
|
||||
b64_json = getattr(item, "b64_json", None)
|
||||
if b64_json:
|
||||
outputs.append(_write_openai_image_file(b64_json, output_format))
|
||||
continue
|
||||
url = getattr(item, "url", None)
|
||||
if url:
|
||||
outputs.append(url)
|
||||
outputs.append(_rewrite_openai_image_url(str(url)))
|
||||
return outputs
|
||||
|
||||
|
||||
@ -298,7 +303,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
|
||||
"to_wxid": from_wx_id,
|
||||
"image_urls": remote_urls,
|
||||
}
|
||||
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||
_debug_response("send image url response", response)
|
||||
|
||||
if local_paths:
|
||||
send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/local"
|
||||
@ -306,7 +312,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
|
||||
"to_wxid": from_wx_id,
|
||||
"file_path": local_paths,
|
||||
}
|
||||
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||
_debug_response("send image local response", response)
|
||||
|
||||
|
||||
def _cleanup_openai_temp_files(image_outputs: list[str]) -> None:
|
||||
@ -490,7 +497,8 @@ def call_openai(config: dict, prompt: str, model: str,
|
||||
kwargs["output_compression"] = _coerce_int(config.get("output_compression"), 100, 0, 100)
|
||||
|
||||
response = client.images.generate(**kwargs)
|
||||
return _openai_images_from_response(response, output_format)
|
||||
_debug_response("openai images.generate response", response)
|
||||
return _openai_images_from_response(response)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Loading…
Reference in New Issue
Block a user