import os
import time

import matplotlib.pyplot as plt
import torch
from diffusers import DiffusionPipeline
from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionPipeline
from diffusers.utils import make_image_grid
from PIL import Image

def get_inputs(batch_size=1):
    """Build a dict of keyword arguments for generating ``batch_size`` images.

    One CUDA generator is created per sample and seeded with the sample's
    index (0, 1, 2, ...). The module-level ``prompt`` is repeated once per
    sample. Presumably the result is meant to be unpacked into a pipeline
    call -- confirm against the caller.

    Args:
        batch_size: Number of samples to prepare inputs for.

    Returns:
        A dict with the keys ``"prompt"`` (list of prompts), ``"generator"``
        (list of seeded generators) and ``"num_inference_steps"`` (always 20).
    """
    seeded_generators = []
    for seed in range(batch_size):
        seeded_generators.append(torch.Generator("cuda").manual_seed(seed))

    return dict(
        prompt=[prompt] * batch_size,
        generator=seeded_generators,
        num_inference_steps=20,
    )


# Timestamped output path, e.g. "content/20240131123000.png", so repeated runs
# never overwrite each other's images.
output_dir = "content"
current_time = time.strftime("%Y%m%d%H%M%S")
output = f"{output_dir}/{current_time}.png"

# PIL's Image.save() does not create missing parent directories, so make sure
# the target directory exists before anything tries to write into it.
os.makedirs(output_dir, exist_ok=True)

##############################
# pipe = StableDiffusionPipeline.from_pretrained(
#     # "CompVis/stable-diffusion-v1-4",
#     "cagliostrolab/animagine-xl-3.1",
#     # revision="fp16",
#     torch_dtype=torch.float16,
#     # use_auth_token=True
# )
# pipe = pipe.to("cuda")

# prompt = "a photo of an astronaut riding a horse on mars"
# pipe.enable_attention_slicing()
# with torch.autocast("cuda"):
#     image = pipe(prompt).images[0]

#     image.save(output)
#     image = Image.open(output)
#     plt.imshow(image)
#     plt.axis('off') # to hide the axis
##############################


# Checkpoint to load. Animagine XL 3.1 is an SDXL-based model, so it must not
# be loaded with the SD 1.x-specific StableDiffusionPipeline class.
# model_id = "runwayml/stable-diffusion-v1-5"
model_id = "cagliostrolab/animagine-xl-3.1"

# DiffusionPipeline.from_pretrained reads the repository's model_index.json and
# instantiates the matching pipeline class, so this line works for both the
# SDXL checkpoint above and the SD 1.5 alternative.
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    use_safetensors=True,
)

# Replace the checkpoint's default scheduler with DPMSolverMultistepScheduler,
# reusing the existing scheduler configuration.
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

prompt = "portrait photo of a old warrior chief"

# .to("cuda") moves the pipeline to the GPU; `pipe` is kept as an alias for
# the same pipeline object.
pipe = pipeline.to("cuda")

# Fixed seed so the generated image is reproducible between runs.
generator = torch.Generator("cuda").manual_seed(0)
image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]

# Persist the result, then reload it from disk and display it without axes.
image.save(output)
image = Image.open(output)
plt.imshow(image)
plt.axis('off')  # hide the axis ticks and frame
# Without show() nothing is displayed when this file is run as a plain script.
plt.show()
