import torch
# from torch import autocast
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler,AutoencoderKL,DPMSolverMultistepScheduler,StableDiffusionLatentUpscalePipeline
import matplotlib.pyplot as plt
from PIL import Image
import time
from diffusers import DiffusionPipeline
import os
import sys
import requests
import utils
import json
import torch


def notify_discord(output):
    """Post the image file at *output* to the configured Discord webhook.

    The webhook URL is read from ``utils.get_config()['discord_webhook_url']``.
    The message text is the fixed string "img". The HTTP response is not
    inspected.

    Args:
        output: Path of the image file to attach.
    """
    config = utils.get_config()
    discord_webhook_url = config['discord_webhook_url']  # webhook URL issued by Discord

    data = {"content": "img"}
    # Context manager guarantees the file handle is closed after the upload
    # (the previous version leaked one handle per call).
    with open(output, 'rb') as image_file:
        requests.post(
            discord_webhook_url,
            data=data,
            files={'file': image_file}
        )

def notify_line(putput):
    """Send the image file at *putput* to LINE Notify.

    The message body is the current local time formatted as
    ``%Y%m%d%H%M%S``. The access token is read from
    ``utils.get_config()['line_notify_token']``. The HTTP response is not
    inspected.

    NOTE(review): LINE Notify was announced to be discontinued in 2025;
    confirm the endpoint below is still served.

    Args:
        putput: Path of the image file to attach. The name is a historical
            misspelling of ``output``. It is kept so keyword callers keep working.
    """
    config = utils.get_config()
    line_notify_token = config['line_notify_token']  # access token issued by LINE Notify

    line_notify_api = "https://notify-api.line.me/api/notify"

    current_time = time.strftime("%Y%m%d%H%M%S")

    payload = {"message": current_time}  # message body
    headers = {"Authorization": "Bearer " + line_notify_token}
    # Bug fix: the old body opened the undefined name ``output``. It only worked
    # when a module-level ``output`` global happened to exist. It also never
    # closed the file.
    with open(putput, 'rb') as image_file:
        requests.post(
            line_notify_api,
            data=payload,
            headers=headers,
            files={'imageFile': image_file}
        )




def _load_pipeline(model, lora, lora_dir):
    """Build an fp16 SDXL pipeline on CUDA, optionally with a LoRA applied.

    Args:
        model: Hugging Face repo id or local diffusers directory path.
        lora: LoRA weight file name inside *lora_dir*, or a falsy value for none.
        lora_dir: Directory containing LoRA weight files.

    Returns:
        The configured ``StableDiffusionXLPipeline``, already moved to 'cuda'.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model,
        torch_dtype=torch.float16,
    )
    # Disable the safety checker so outputs are not replaced by black images.
    pipe.safety_checker = None
    pipe.requires_safety_checker = False

    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to('cuda')
    # Optional VRAM savers if needed:
    #   pipe.enable_attention_slicing()
    #   pipe.enable_sequential_cpu_offload()

    if lora:
        pipe.load_lora_weights(lora_dir, weight_name=lora)
    return pipe


def _upload_image(update_url, entry_id, image_path):
    """POST the generated image together with its job id; return the response."""
    # Context manager closes the file after the upload (previously leaked).
    with open(image_path, 'rb') as image_file:
        # NOTE(security): verify=False disables TLS certificate verification.
        # It is kept as-is, presumably for a self-signed endpoint. Confirm this,
        # and remove it if possible.
        return requests.post(
            update_url,
            data={"id": entry_id},
            verify=False,
            files={'file': image_file}
        )


if __name__ == '__main__':

    # Optional speed-ups (disabled):
    #   torch.backends.cudnn.benchmark = True          # enable the cuDNN auto-tuner
    #   torch.backends.cuda.matmul.allow_tf32 = True   # TF32 instead of FP32 on Ampere+ GPUs

    base_dir = os.path.join('C:' + os.sep, 'dev', 'ai', 'stable-diffusion-webui-forge')
    lora_dir = os.path.join(base_dir, 'models', 'lora')
    model_dir = os.path.join(base_dir, 'models', 'diffusers')

    config = utils.get_config()
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36",
    }

    width = config['width']
    height = config['height']
    create_count = config['create_count']

    # Generated images are written here; create it so image.save() cannot fail
    # on a fresh checkout.
    os.makedirs("content", exist_ok=True)

    # NOTE(security): verify=False disables TLS certificate verification.
    # Kept as-is; confirm this is required by the endpoint.
    r = requests.get(config['check_url'], headers=headers, verify=False)
    json_data = r.json()

    for data in json_data:
        # Load the pipeline once per job entry rather than once per image.
        # Previously it was reloaded create_count times.
        try:
            model = data['model']
            lora = data['lora']
            prompt = data['prompt']
            negative_prompt = data['negative_prompt']
            # A bare name (no '/') refers to a local diffusers model directory;
            # anything with '/' is passed through, e.g. a Hub repo id.
            if '/' not in model:
                model = os.path.join(model_dir, model)

            print(f"■Model: {model}")
            print(f"■Lora: {lora}")
            print(f"■Prompt: {prompt}")
            print(f"■Negative Prompt: {negative_prompt}")
            print(f"■Width: {width}")
            print(f"■Height: {height}")

            pipe = _load_pipeline(model, lora, lora_dir)
        except Exception as e:
            print(f"Error: {e}")
            continue

        try:
            for _ in range(create_count):
                print("---------------start--------------")
                try:
                    current_time = time.strftime("%Y%m%d%H%M%S")
                    # Kept as a module-level global: notify_line historically read it.
                    output = "content/" + str(current_time) + ".png"

                    # No seeded generator is passed, so each of the create_count
                    # images differs.
                    image = pipe(
                        prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        guidance_scale=5,
                        num_inference_steps=50,
                        # SDXL micro-conditioning sizes are (height, width). The
                        # previous (width, height) order was wrong for non-square output.
                        target_size=(height * 4, width * 4),
                        original_size=(height * 8, width * 8),
                    ).images[0]

                    image.save(output)
                    # The former plt.imshow() display was removed. Nothing ever
                    # showed the figure, and each call piled another image onto
                    # the same axes, growing memory.

                    print("line notify")
                    notify_line(output)
                    print("line notify end")

                    print("post")
                    response = _upload_image(config['update_url'], data['id'], output)

                    if response.status_code == 200:
                        print("Success!")
                        print(response.json())  # print the server's response body
                    else:
                        print("POST request failed with status code:", response.status_code)

                    print("posted")

                except Exception as e:
                    # Top-level boundary: report and move on to the next image.
                    print(f"Error: {e}")
                    continue
                finally:
                    print("---------------end--------------")
                    print("")
        finally:
            # Release the pipeline before the next entry loads another model.
            del pipe
            torch.cuda.empty_cache()




