fix: 优化发送图片逻辑
This commit is contained in:
parent
569533e9d9
commit
0a14269187
@ -281,12 +281,55 @@ def _rewrite_openai_image_url(url: str) -> str:
|
||||
return url
|
||||
|
||||
|
||||
def _openai_images_from_response(response) -> list[str]:
|
||||
def _extension_from_output_format(output_format: str) -> str:
|
||||
if output_format == "jpeg":
|
||||
return ".jpg"
|
||||
if output_format == "webp":
|
||||
return ".webp"
|
||||
return ".png"
|
||||
|
||||
|
||||
def _openai_response_value(item, key: str):
|
||||
if isinstance(item, dict):
|
||||
return item.get(key)
|
||||
return getattr(item, key, None)
|
||||
|
||||
|
||||
def _write_openai_b64_image(b64_json: str, output_format: str) -> str:
|
||||
encoded = b64_json.strip()
|
||||
suffix = _extension_from_output_format(output_format)
|
||||
if encoded.startswith("data:"):
|
||||
header, encoded = encoded.split(",", 1)
|
||||
mime_type = header[5:].split(";", 1)[0].strip().lower()
|
||||
if mime_type:
|
||||
suffix = _extension_from_mime(mime_type)
|
||||
|
||||
encoded = "".join(encoded.split())
|
||||
padding = len(encoded) % 4
|
||||
if padding:
|
||||
encoded = f"{encoded}{'=' * (4 - padding)}"
|
||||
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
with tempfile.NamedTemporaryFile(prefix="wechat-openai-image-", suffix=suffix, delete=False) as temp_file:
|
||||
temp_file.write(image_bytes)
|
||||
return temp_file.name
|
||||
|
||||
|
||||
def _openai_images_from_response(response, output_format: str) -> list[str]:
|
||||
outputs: list[str] = []
|
||||
for item in getattr(response, "data", []) or []:
|
||||
url = getattr(item, "url", None)
|
||||
if url:
|
||||
outputs.append(_rewrite_openai_image_url(str(url)))
|
||||
try:
|
||||
for item in getattr(response, "data", []) or []:
|
||||
b64_json = _openai_response_value(item, "b64_json")
|
||||
if b64_json:
|
||||
outputs.append(_write_openai_b64_image(str(b64_json), output_format))
|
||||
continue
|
||||
|
||||
url = _openai_response_value(item, "url")
|
||||
if url:
|
||||
outputs.append(_rewrite_openai_image_url(str(url)))
|
||||
except Exception:
|
||||
_cleanup_openai_temp_files(outputs)
|
||||
raise
|
||||
return outputs
|
||||
|
||||
|
||||
@ -541,7 +584,7 @@ def call_openai(config: dict, prompt: str, model: str, images: list[str],
|
||||
input_file.close()
|
||||
|
||||
_debug_response("openai images.edit response", response)
|
||||
return _openai_images_from_response(response)
|
||||
return _openai_images_from_response(response, output_format)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -3,11 +3,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
import urllib.parse
|
||||
@ -280,12 +283,64 @@ def _rewrite_openai_image_url(url: str) -> str:
|
||||
return url
|
||||
|
||||
|
||||
def _openai_images_from_response(response) -> list[str]:
|
||||
def _extension_from_mime(mime_type: str) -> str:
|
||||
if mime_type == "image/jpeg":
|
||||
return ".jpg"
|
||||
guessed = mimetypes.guess_extension(mime_type)
|
||||
if guessed in {".png", ".jpg", ".jpeg", ".webp"}:
|
||||
return guessed
|
||||
return ".png"
|
||||
|
||||
|
||||
def _extension_from_output_format(output_format: str) -> str:
|
||||
if output_format == "jpeg":
|
||||
return ".jpg"
|
||||
if output_format == "webp":
|
||||
return ".webp"
|
||||
return ".png"
|
||||
|
||||
|
||||
def _openai_response_value(item, key: str):
|
||||
if isinstance(item, dict):
|
||||
return item.get(key)
|
||||
return getattr(item, key, None)
|
||||
|
||||
|
||||
def _write_openai_b64_image(b64_json: str, output_format: str) -> str:
|
||||
encoded = b64_json.strip()
|
||||
suffix = _extension_from_output_format(output_format)
|
||||
if encoded.startswith("data:"):
|
||||
header, encoded = encoded.split(",", 1)
|
||||
mime_type = header[5:].split(";", 1)[0].strip().lower()
|
||||
if mime_type:
|
||||
suffix = _extension_from_mime(mime_type)
|
||||
|
||||
encoded = "".join(encoded.split())
|
||||
padding = len(encoded) % 4
|
||||
if padding:
|
||||
encoded = f"{encoded}{'=' * (4 - padding)}"
|
||||
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
with tempfile.NamedTemporaryFile(prefix="wechat-openai-image-", suffix=suffix, delete=False) as temp_file:
|
||||
temp_file.write(image_bytes)
|
||||
return temp_file.name
|
||||
|
||||
|
||||
def _openai_images_from_response(response, output_format: str) -> list[str]:
|
||||
outputs: list[str] = []
|
||||
for item in getattr(response, "data", []) or []:
|
||||
url = getattr(item, "url", None)
|
||||
if url:
|
||||
outputs.append(_rewrite_openai_image_url(str(url)))
|
||||
try:
|
||||
for item in getattr(response, "data", []) or []:
|
||||
b64_json = _openai_response_value(item, "b64_json")
|
||||
if b64_json:
|
||||
outputs.append(_write_openai_b64_image(str(b64_json), output_format))
|
||||
continue
|
||||
|
||||
url = _openai_response_value(item, "url")
|
||||
if url:
|
||||
outputs.append(_rewrite_openai_image_url(str(url)))
|
||||
except Exception:
|
||||
_cleanup_openai_temp_files(outputs)
|
||||
raise
|
||||
return outputs
|
||||
|
||||
|
||||
@ -498,7 +553,7 @@ def call_openai(config: dict, prompt: str, model: str,
|
||||
|
||||
response = client.images.generate(**kwargs)
|
||||
_debug_response("openai images.generate response", response)
|
||||
return _openai_images_from_response(response)
|
||||
return _openai_images_from_response(response, output_format)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Loading…
Reference in New Issue
Block a user