【百度智能云】AIGC-文生图

从地址 https://console.bce.baidu.com/ai/#/ai/intelligentwriting/app/list 创建应用，然后将应用的 API Key 和 Secret Key 分别复制到下文的 API_KEY_AIGC 和 SECRET_KEY_AIGC 中。

import json
import os
import time  # debug
import http.client
import ssl
from urllib.parse import urlparse
import certifi

# Cached OAuth access token; filled in lazily (via RefreshAccTokenZHY) the first
# time a request function finds it unset.
accesstoken_AIGC = None
# Number of token refreshes performed so far; RefreshAccTokenZHY raises once
# this exceeds 1, i.e. only a single refresh is allowed per process.
REFRESH = 0
# Result record shared by the request functions: "imageUrl" is set by
# getUrlByTaskId and "down" by downloadImg.
# NOTE(review): "resize" and "pull" are not written by the functions shown here — confirm their use.
GETRESULT = {"imageUrl": "", "down": "", "resize": "", "pull": ""}
# Application credentials from the Baidu AI Cloud console. Prefer supplying
# them via the environment so real secrets are never committed to source;
# the literals below are placeholders used only when the variables are unset.
API_KEY_AIGC = os.environ.get("API_KEY_AIGC", "L7xxxxxxxxxxxxxxxxxxxxxxxxY")
SECRET_KEY_AIGC = os.environ.get("SECRET_KEY_AIGC", "Bxxxxxxxxxxxxxxxxxxxxxxxxx8")
USE_BASE = True  # Use the basic (v1) API; False selects the advanced v2 API, which costs more.
# NOTE(review): not referenced by the functions in this file section — confirm it is still needed.
watch_path = ""

### -------- Text2Image -- start
def _extract_task_id(result):
    """Return the task id from a txt2img / txt2imgv2 JSON response, or None.

    The basic endpoint answers e.g.
    ``{"data": {"taskId": 18129193}, "log_id": 1735618138782398551}``; the
    advanced endpoint uses the snake_case ``task_id`` key (the same key its
    result endpoint expects in the request body).
    """
    data = result.get("data") if isinstance(result, dict) else None
    if not isinstance(data, dict):
        return None
    if "taskId" in data:
        return data["taskId"]
    return data.get("task_id")


def Text2Image(txt):
    """Submit a text-to-image task and wait for its result.

    Posts *txt* to the ERNIE-ViLG ``txt2img`` endpoint (``txt2imgv2`` when
    ``USE_BASE`` is False), then passes the returned task id to
    ``getUrlByTaskId`` and returns whatever that returns.

    Args:
        txt: The prompt. ``None``, empty and whitespace-only prompts are rejected.

    Returns:
        The value returned by ``getUrlByTaskId``, or ``None`` when the request
        fails (failures are printed, not raised).

    Raises:
        Exception: if *txt* is empty, or if no access token can be obtained.
    """
    if txt is None or not str(txt).strip():
        raise Exception("解析音频为空,取消当前次请求")
    global accesstoken_AIGC
    if accesstoken_AIGC is None:
        accesstoken_AIGC = RefreshAccTokenZHY()
        if accesstoken_AIGC is None:
            # Without a token the URL cannot be built; fail loudly instead of a TypeError.
            raise Exception("Text2Image: failed to obtain access token")
    if USE_BASE:
        url = "https://aip.baidubce.com/rpc/2.0/ernievilg/v1/txt2img?access_token=" + accesstoken_AIGC
        payload = json.dumps({"text": txt, "resolution": "1024*1024", "style": "写实风格"})
    else:
        url = "https://aip.baidubce.com/rpc/2.0/ernievilg/v1/txt2imgv2?access_token=" + accesstoken_AIGC
        payload = json.dumps({"prompt": txt, "width": 1024, "height": 1024})
    headers = {"Content-Type": "application/json", "Accept": "application/json"}
    try:
        started = time.time()
        parsed_uri = urlparse(url)
        conn = CreateHTTPSConn(parsed_uri.hostname)
        try:
            conn.request("POST", f"{parsed_uri.path}?{parsed_uri.query}", payload, headers)
            response = conn.getresponse()
            status = response.status
            body = response.read().decode()
        finally:
            # Always release the socket, even if the request or read fails.
            conn.close()
        if status != 200:
            print(f"Failed to Text2Image. Status code: {status}")
            return None
        result = json.loads(body)
        print("Text2Image texttexttexttexttexttext", result)
        task_id = _extract_task_id(result)
        if task_id is None:
            print("# Text2Image getimg fail!")
            return None
        print("# request time %s" % (time.time() - started), task_id)
        return getUrlByTaskId(task_id)
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
### -------- Text2Image -- end

### -------- getUrlByTaskId -- start
def _extract_image_url(data):
    """Return the finished image URL from a getImg / getImgv2 ``data`` object.

    Returns None while the task is still running.

    Raises:
        Exception: if the task reports failure, or finished without an image URL.
    """
    if not isinstance(data, dict):
        return None
    # Basic (v1) result: status == 1 means finished and "img" holds the URL.
    if data.get("status") == 1:
        if not data.get("img"):
            raise Exception(f"getUrlByTaskId: task finished without an image: {data}")
        return data["img"]
    # Advanced (v2) result: task_status, with URLs nested per sub-task.
    task_status = data.get("task_status")
    if task_status == "SUCCESS":
        for sub_task in data.get("sub_task_result_list") or []:
            for image in sub_task.get("final_image_list") or []:
                if image.get("img_url"):
                    return image["img_url"]
        raise Exception(f"getUrlByTaskId: task finished without an image: {data}")
    if task_status == "FAILED":
        raise Exception(f"getUrlByTaskId: task failed: {data}")
    return None


def getUrlByTaskId(id, *, poll_interval=2.0, max_attempts=60):
    """Poll the result endpoint for task *id* until its image is ready.

    On success stores the URL in ``GETRESULT["imageUrl"]``, downloads the image
    via ``downloadImg`` and returns ``json.dumps(GETRESULT)``.

    Args:
        id: Task id returned by the submit endpoint (sent as a string).
            (Name shadows the builtin; kept for caller compatibility.)
        poll_interval: Seconds to wait between polls.
        max_attempts: Maximum number of polls before giving up.

    Returns:
        The JSON-encoded ``GETRESULT`` on success, otherwise ``None``
        (failures are printed, not raised).

    Raises:
        Exception: if no access token can be obtained.
    """
    global accesstoken_AIGC
    if accesstoken_AIGC is None:
        accesstoken_AIGC = RefreshAccTokenZHY()
        if accesstoken_AIGC is None:
            raise Exception("getUrlByTaskId: failed to obtain access token")
    if USE_BASE:
        url = "https://aip.baidubce.com/rpc/2.0/ernievilg/v1/getImg?access_token=" + accesstoken_AIGC
        payload = json.dumps({"taskId": str(id)})
    else:
        url = "https://aip.baidubce.com/rpc/2.0/ernievilg/v1/getImgv2?access_token=" + accesstoken_AIGC
        payload = json.dumps({"task_id": str(id)})
    headers = {"Content-Type": "application/json", "Accept": "application/json"}
    parsed_uri = urlparse(url)
    path = f"{parsed_uri.path}?{parsed_uri.query}"
    try:
        # Bounded, paced polling: the previous immediate self-recursion hammered
        # the API and eventually hit the recursion limit.
        for attempt in range(max(1, max_attempts)):
            if attempt:
                time.sleep(poll_interval)
            conn = CreateHTTPSConn(parsed_uri.hostname)
            try:
                conn.request("POST", path, payload, headers)
                response = conn.getresponse()
                status = response.status
                body = response.read().decode()
            finally:
                conn.close()
            if status != 200:
                print(f"Failed to getUrlByTaskId. Status code: {status}")
                return None
            result = json.loads(body)
            print("getUrlByTaskId text2ext", result)
            if "data" not in result and result.get("error_code"):
                # API-level errors (e.g. an invalid token) will not clear by polling again.
                print(f"An error occurred: {result.get('error_msg')}")
                return None
            img_url = _extract_image_url(result.get("data"))
            if img_url:
                GETRESULT["imageUrl"] = img_url
                print("# request text33333333333332ext time ", img_url, GETRESULT)
                downloadImg(img_url)
                return json.dumps(GETRESULT)
        print(f"getUrlByTaskId: task {id} not finished after {max_attempts} attempts")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
### -------- getUrlByTaskId -- end

### -------- img download and resize -- start
def downloadImg(url, save_path="dialImageO.jpg"):
    """Download the image at the HTTPS *url* and write it to *save_path*.

    Certificates are verified against certifi's CA bundle. On success sets
    ``GETRESULT["down"] = 1``.

    Args:
        url: Absolute https URL of the image.
        save_path: Destination file; defaults to ``dialImageO.jpg`` in the
            current working directory (the previously hard-coded name).

    Returns:
        *save_path* on success, or ``None`` if the server did not answer 200.
    """
    parsed_uri = urlparse(url)
    ssl_ctx = ssl.create_default_context(cafile=certifi.where())
    # Explicit timeout so a stalled image host cannot block forever.
    connection = http.client.HTTPSConnection(
        parsed_uri.hostname, parsed_uri.port, context=ssl_ctx, timeout=60
    )
    # Send origin-form "path?query" as the request target, not the absolute URL.
    target = parsed_uri.path or "/"
    if parsed_uri.query:
        target += "?" + parsed_uri.query
    try:
        connection.request("GET", target)
        response = connection.getresponse()
        if response.status != 200:
            print(f"HTTP 请求失败: {response.status}, {response.reason}")
            return None
        image_data = response.read()
    finally:
        connection.close()
    with open(save_path, "wb") as file:
        file.write(image_data)
    print(f"Image downloaded and saved as {save_path}")
    GETRESULT["down"] = 1
    return save_path
### -------- img download and resize -- end

def RefreshAccTokenZHY(_legacy_flag=None):
    """Request a new OAuth access token for the AIGC application.

    Only one refresh is allowed per process: the module-level ``REFRESH``
    counter is incremented on every call and any call after the first raises.

    Args:
        _legacy_flag: Ignored. Accepted only so existing call sites written as
            ``RefreshAccTokenZHY(False)`` work instead of raising TypeError.

    Returns:
        The access token string, or ``None`` if the request failed or the
        response contained no ``access_token``.

    Raises:
        Exception: when called more than once ("refresh ZHY token times limit!").
    """
    from urllib.parse import urlencode

    global REFRESH
    REFRESH += 1
    if REFRESH > 1:
        raise Exception("refresh ZHY token times limit!")

    # Encode the credentials rather than splicing them into the query string.
    query = urlencode(
        {
            "client_id": API_KEY_AIGC,
            "client_secret": SECRET_KEY_AIGC,
            "grant_type": "client_credentials",
        }
    )
    path = "/oauth/2.0/token?" + query
    payload = json.dumps("")
    headers = {"Content-Type": "application/json"}
    try:
        conn = CreateHTTPSConn("aip.baidubce.com")
        try:
            conn.request("POST", path, payload, headers)
            response = conn.getresponse()
            status = response.status
            body = response.read().decode()
        finally:
            conn.close()
        if status != 200:
            print(f"Failed to RefreshAccToken ZHY. Status code: {status}, body: {body}")
            return None
        result = json.loads(body)
        access_token = result.get("access_token")
        if not access_token:
            print(f"RefreshAccToken ZHY: no access_token in response: {result}")
            return None
        # Print only a prefix: the full token is a credential and must not reach logs.
        print(f"ZHY access_token: {access_token[:8]}...")
        return access_token
    except Exception as e:
        print(f"RefreshAccToken ZHY An error occurred: {e}")
        return None

def CreateHTTPSConn(host, port=None, timeout=60):
    """Create an HTTPS connection to *host* verified against certifi's CA bundle.

    The connection is not opened until the first request is sent.

    Args:
        host: Server hostname.
        port: Server port; ``None`` means the HTTPS default (443).
        timeout: Socket timeout in seconds, so a stalled server cannot hang
            the caller forever (previously there was no timeout).

    Returns:
        An ``http.client.HTTPSConnection``; the caller is responsible for closing it.
    """
    # create_default_context already sets CERT_REQUIRED and enables hostname checking.
    ctx = ssl.create_default_context(cafile=certifi.where())
    return http.client.HTTPSConnection(host, port, timeout=timeout, context=ctx)

if __name__ == "__main__":
    # Manual smoke test: generate an image for a sample prompt and show the outcome.
    outcome = Text2Image("愤怒的小鸟")
    print("222>> ", outcome)
posted @ 2024-02-22 14:00  叫夏洛啊  阅读(113)  评论(0)  编辑  收藏  举报