fix: 文生图、图生图支持 GPT
This commit is contained in:
parent
20ea36a340
commit
f812638e44
13
.vscode/launch.json
vendored
13
.vscode/launch.json
vendored
@ -5,20 +5,9 @@
|
|||||||
"name": "text-to-image",
|
"name": "text-to-image",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "skills/text-to-image/scripts/text_to_image.py",
|
"program": "skills/text-to-image/scripts/debug_openai_image_generation_test.py",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"justMyCode": true,
|
"justMyCode": true,
|
||||||
"args": [
|
|
||||||
"{\"prompt\":\"一只站在雨夜街头的白猫\",\"model\":\"jimeng-5.0\",\"negative_prompt\":\"模糊, 低清\",\"ratio\":\"16:9\",\"resolution\":\"2k\"}"
|
|
||||||
],
|
|
||||||
"env": {
|
|
||||||
"ROBOT_FROM_WX_ID": "57004904192@chatroom",
|
|
||||||
"ROBOT_CODE": "houhouipad",
|
|
||||||
"MYSQL_HOST": "127.0.0.1",
|
|
||||||
"MYSQL_PORT": "3306",
|
|
||||||
"MYSQL_USER": "root",
|
|
||||||
"MYSQL_PASSWORD": "houhou"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -236,50 +236,57 @@ def _openai_client(config: dict) -> OpenAI:
|
|||||||
if not api_key:
|
if not api_key:
|
||||||
raise RuntimeError("OpenAI 绘图配置缺少 api_key")
|
raise RuntimeError("OpenAI 绘图配置缺少 api_key")
|
||||||
|
|
||||||
kwargs: dict[str, str | float] = {"api_key": api_key}
|
|
||||||
base_url = str(config.get("base_url", "") or "").strip()
|
base_url = str(config.get("base_url", "") or "").strip()
|
||||||
if base_url:
|
|
||||||
kwargs["base_url"] = base_url
|
|
||||||
organization = str(config.get("organization", "") or "").strip()
|
organization = str(config.get("organization", "") or "").strip()
|
||||||
if organization:
|
|
||||||
kwargs["organization"] = organization
|
|
||||||
project = str(config.get("project", "") or "").strip()
|
project = str(config.get("project", "") or "").strip()
|
||||||
if project:
|
timeout: float | None = None
|
||||||
kwargs["project"] = project
|
|
||||||
timeout_value = config.get("timeout")
|
timeout_value = config.get("timeout")
|
||||||
if timeout_value not in (None, ""):
|
if timeout_value not in (None, ""):
|
||||||
kwargs["timeout"] = float(timeout_value)
|
timeout = float(timeout_value)
|
||||||
|
|
||||||
return OpenAI(**kwargs)
|
return OpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url or None,
|
||||||
|
organization=organization or None,
|
||||||
|
project=project or None,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _openai_image_suffix(output_format: str) -> str:
|
def _truncate_debug_payload(value):
|
||||||
if output_format == "jpeg":
|
if isinstance(value, dict):
|
||||||
return ".jpg"
|
return {
|
||||||
return f".{output_format}"
|
key: (
|
||||||
|
f"{item[:50]}..." if key == "b64_json" and isinstance(item, str) and len(item) > 50 else _truncate_debug_payload(item)
|
||||||
|
)
|
||||||
|
for key, item in value.items()
|
||||||
|
}
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_truncate_debug_payload(item) for item in value]
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _write_openai_image_file(b64_json: str, output_format: str) -> str:
|
def _debug_response(label: str, payload) -> None:
|
||||||
image_bytes = base64.b64decode(b64_json)
|
if hasattr(payload, "model_dump"):
|
||||||
with tempfile.NamedTemporaryFile(
|
payload = payload.model_dump()
|
||||||
prefix="wechat-openai-image-",
|
payload = _truncate_debug_payload(payload)
|
||||||
suffix=_openai_image_suffix(output_format),
|
sys.stdout.write(f"[debug] {label}: {json.dumps(payload, ensure_ascii=False)}\n")
|
||||||
delete=False,
|
|
||||||
) as image_file:
|
|
||||||
image_file.write(image_bytes)
|
|
||||||
return image_file.name
|
|
||||||
|
|
||||||
|
|
||||||
def _openai_images_from_response(response, output_format: str) -> list[str]:
|
def _rewrite_openai_image_url(url: str) -> str:
|
||||||
|
internal_host = "http://chatgpt2api:80"
|
||||||
|
external_host = "http://chatgpt2api.houhoukang.com"
|
||||||
|
if url.startswith(internal_host):
|
||||||
|
return f"{external_host}{url[len(internal_host):]}"
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_images_from_response(response) -> list[str]:
|
||||||
outputs: list[str] = []
|
outputs: list[str] = []
|
||||||
for item in getattr(response, "data", []) or []:
|
for item in getattr(response, "data", []) or []:
|
||||||
b64_json = getattr(item, "b64_json", None)
|
|
||||||
if b64_json:
|
|
||||||
outputs.append(_write_openai_image_file(b64_json, output_format))
|
|
||||||
continue
|
|
||||||
url = getattr(item, "url", None)
|
url = getattr(item, "url", None)
|
||||||
if url:
|
if url:
|
||||||
outputs.append(url)
|
outputs.append(_rewrite_openai_image_url(str(url)))
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@ -297,7 +304,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
|
|||||||
"to_wxid": from_wx_id,
|
"to_wxid": from_wx_id,
|
||||||
"image_urls": remote_urls,
|
"image_urls": remote_urls,
|
||||||
}
|
}
|
||||||
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||||
|
_debug_response("send image url response", response)
|
||||||
|
|
||||||
if local_paths:
|
if local_paths:
|
||||||
send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/local"
|
send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/local"
|
||||||
@ -305,7 +313,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
|
|||||||
"to_wxid": from_wx_id,
|
"to_wxid": from_wx_id,
|
||||||
"file_path": local_paths,
|
"file_path": local_paths,
|
||||||
}
|
}
|
||||||
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||||
|
_debug_response("send image local response", response)
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_openai_temp_files(image_outputs: list[str]) -> None:
|
def _cleanup_openai_temp_files(image_outputs: list[str]) -> None:
|
||||||
@ -531,7 +540,8 @@ def call_openai(config: dict, prompt: str, model: str, images: list[str],
|
|||||||
for input_file in input_files:
|
for input_file in input_files:
|
||||||
input_file.close()
|
input_file.close()
|
||||||
|
|
||||||
return _openai_images_from_response(response, output_format)
|
_debug_response("openai images.edit response", response)
|
||||||
|
return _openai_images_from_response(response)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@ -3,13 +3,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import base64
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
@ -237,50 +235,57 @@ def _openai_client(config: dict) -> OpenAI:
|
|||||||
if not api_key:
|
if not api_key:
|
||||||
raise RuntimeError("OpenAI 绘图配置缺少 api_key")
|
raise RuntimeError("OpenAI 绘图配置缺少 api_key")
|
||||||
|
|
||||||
kwargs: dict[str, str | float] = {"api_key": api_key}
|
|
||||||
base_url = str(config.get("base_url", "") or "").strip()
|
base_url = str(config.get("base_url", "") or "").strip()
|
||||||
if base_url:
|
|
||||||
kwargs["base_url"] = base_url
|
|
||||||
organization = str(config.get("organization", "") or "").strip()
|
organization = str(config.get("organization", "") or "").strip()
|
||||||
if organization:
|
|
||||||
kwargs["organization"] = organization
|
|
||||||
project = str(config.get("project", "") or "").strip()
|
project = str(config.get("project", "") or "").strip()
|
||||||
if project:
|
timeout: float | None = None
|
||||||
kwargs["project"] = project
|
|
||||||
timeout_value = config.get("timeout")
|
timeout_value = config.get("timeout")
|
||||||
if timeout_value not in (None, ""):
|
if timeout_value not in (None, ""):
|
||||||
kwargs["timeout"] = float(timeout_value)
|
timeout = float(timeout_value)
|
||||||
|
|
||||||
return OpenAI(**kwargs)
|
return OpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url or None,
|
||||||
|
organization=organization or None,
|
||||||
|
project=project or None,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _openai_image_suffix(output_format: str) -> str:
|
def _truncate_debug_payload(value):
|
||||||
if output_format == "jpeg":
|
if isinstance(value, dict):
|
||||||
return ".jpg"
|
return {
|
||||||
return f".{output_format}"
|
key: (
|
||||||
|
f"{item[:50]}..." if key == "b64_json" and isinstance(item, str) and len(item) > 50 else _truncate_debug_payload(item)
|
||||||
|
)
|
||||||
|
for key, item in value.items()
|
||||||
|
}
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_truncate_debug_payload(item) for item in value]
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _write_openai_image_file(b64_json: str, output_format: str) -> str:
|
def _debug_response(label: str, payload) -> None:
|
||||||
image_bytes = base64.b64decode(b64_json)
|
if hasattr(payload, "model_dump"):
|
||||||
with tempfile.NamedTemporaryFile(
|
payload = payload.model_dump()
|
||||||
prefix="wechat-openai-image-",
|
payload = _truncate_debug_payload(payload)
|
||||||
suffix=_openai_image_suffix(output_format),
|
sys.stdout.write(f"[debug] {label}: {json.dumps(payload, ensure_ascii=False)}\n")
|
||||||
delete=False,
|
|
||||||
) as image_file:
|
|
||||||
image_file.write(image_bytes)
|
|
||||||
return image_file.name
|
|
||||||
|
|
||||||
|
|
||||||
def _openai_images_from_response(response, output_format: str) -> list[str]:
|
def _rewrite_openai_image_url(url: str) -> str:
|
||||||
|
internal_host = "http://chatgpt2api:80"
|
||||||
|
external_host = "http://chatgpt2api.houhoukang.com"
|
||||||
|
if url.startswith(internal_host):
|
||||||
|
return f"{external_host}{url[len(internal_host):]}"
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_images_from_response(response) -> list[str]:
|
||||||
outputs: list[str] = []
|
outputs: list[str] = []
|
||||||
for item in getattr(response, "data", []) or []:
|
for item in getattr(response, "data", []) or []:
|
||||||
b64_json = getattr(item, "b64_json", None)
|
|
||||||
if b64_json:
|
|
||||||
outputs.append(_write_openai_image_file(b64_json, output_format))
|
|
||||||
continue
|
|
||||||
url = getattr(item, "url", None)
|
url = getattr(item, "url", None)
|
||||||
if url:
|
if url:
|
||||||
outputs.append(url)
|
outputs.append(_rewrite_openai_image_url(str(url)))
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@ -298,7 +303,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
|
|||||||
"to_wxid": from_wx_id,
|
"to_wxid": from_wx_id,
|
||||||
"image_urls": remote_urls,
|
"image_urls": remote_urls,
|
||||||
}
|
}
|
||||||
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||||
|
_debug_response("send image url response", response)
|
||||||
|
|
||||||
if local_paths:
|
if local_paths:
|
||||||
send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/local"
|
send_url = f"http://127.0.0.1:{client_port}/api/v1/robot/message/send/image/local"
|
||||||
@ -306,7 +312,8 @@ def _send_image_outputs(client_port: str, from_wx_id: str, image_outputs: list[s
|
|||||||
"to_wxid": from_wx_id,
|
"to_wxid": from_wx_id,
|
||||||
"file_path": local_paths,
|
"file_path": local_paths,
|
||||||
}
|
}
|
||||||
_http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
response = _http_post_json(send_url, send_body, {"Content-Type": "application/json"}, timeout=60)
|
||||||
|
_debug_response("send image local response", response)
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_openai_temp_files(image_outputs: list[str]) -> None:
|
def _cleanup_openai_temp_files(image_outputs: list[str]) -> None:
|
||||||
@ -490,7 +497,8 @@ def call_openai(config: dict, prompt: str, model: str,
|
|||||||
kwargs["output_compression"] = _coerce_int(config.get("output_compression"), 100, 0, 100)
|
kwargs["output_compression"] = _coerce_int(config.get("output_compression"), 100, 0, 100)
|
||||||
|
|
||||||
response = client.images.generate(**kwargs)
|
response = client.images.generate(**kwargs)
|
||||||
return _openai_images_from_response(response, output_format)
|
_debug_response("openai images.generate response", response)
|
||||||
|
return _openai_images_from_response(response)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user