-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
88 lines (73 loc) · 2.85 KB
/
main.py
File metadata and controls
88 lines (73 loc) · 2.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import io
import json
import os
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from diffusers import AutoPipelineForText2Image
from fastapi.responses import StreamingResponse, JSONResponse
import torch
# Runtime configuration is loaded from a JSON file; the file is expected to
# contain every setting needed to run the service: "model", "variant",
# "root" (API root path) and "endpoint" (route for image generation).
# SD_API_CONFIG_PATH overrides the default "config.json" location.
config_path = os.getenv("SD_API_CONFIG_PATH", "config.json")
with open(config_path, 'r') as f:
    config = json.load(f)
# API Start — root_path lets the app sit behind a reverse proxy prefix.
app = FastAPI(root_path=config["root"])
# Configuration variables
repo_id = config["model"]      # Hugging Face model repo id to load
variant = config["variant"]    # checkpoint variant (e.g. weight precision) to fetch
class Item(BaseModel):
    """Request body for the image-generation endpoint."""
    # Text prompt describing the image to generate (required).
    prompt: str
    # Number of denoising steps; forwarded as num_inference_steps. Defaults to 40.
    inference_steps: int | None = 40
    # Guidance scale; forwarded to the pipeline as guidance_scale. Defaults to 7.5.
    guiding_scale: float | None = 7.5
# Build the text-to-image pipeline from the configured repo/variant.
# bfloat16 weights + safetensors keep the checkpoint load fast and memory-light.
pipeline = AutoPipelineForText2Image.from_pretrained(
    repo_id,
    variant=variant,
    torch_dtype=torch.bfloat16,
    use_safetensors=True
)
# Setup to be more memory efficient, so we don't run out of VRAM.
pipeline.enable_vae_slicing()
pipeline.enable_vae_tiling()
pipeline.enable_attention_slicing(1)
# Model CPU offload moves submodules to the GPU on demand and back to CPU
# afterwards; it manages device placement itself, so the pipeline must NOT
# be moved to CUDA first (a prior .to("cuda") defeats the offload by
# putting the whole model in VRAM before the hooks are installed).
pipeline.enable_model_cpu_offload()
# RNG used for sampling; no manual seed is set, so outputs are non-deterministic.
gen = torch.Generator(device="cuda")
@app.post(config["endpoint"],
          # Set response template to respond with an image
          responses={
              200: {
                  "content": {"image/jpeg": {}}
              },
              500: {
                  "content": {"application/json": {}}
              }
          },
          # Make sure we don't add 'application/json' as a response when sending image. This is an image, nothing else
          response_class=StreamingResponse
          )
def generate_image(item: Item):
    """
    Generate a 1024x1024 image from a text prompt and stream it back as JPEG.

    Args:
        item (Item): request body with:
            prompt (str): the textual prompt for the image generation.
            inference_steps (int, optional): the number of inference steps. Defaults to 40.
            guiding_scale (float, optional): guidance scale. Defaults to 7.5.

    Returns:
        StreamingResponse: the generated image, 1024x1024 pixels, as image/jpeg.

    Raises:
        HTTPException: 500 if JPEG encoding produced no bytes.
    """
    image = pipeline(item.prompt,
                     num_inference_steps=item.inference_steps,
                     guidance_scale=item.guiding_scale,
                     generator=gen,
                     width=1024,
                     height=1024
                     ).images[0]
    # Encode the PIL image to JPEG in memory, then rewind for streaming.
    jpeg_image = io.BytesIO()
    image.save(jpeg_image, format="JPEG")
    jpeg_image.seek(0)
    # Make a quick smoke test to see that we have data to send back
    if jpeg_image.getbuffer().nbytes == 0:
        raise HTTPException(status_code=500, detail="Failed to generate image")
    return StreamingResponse(jpeg_image, media_type="image/jpeg")