# api.py
import torch
from accelerate import Accelerator
from fastapi import FastAPI, HTTPException
from typing import List, Dict
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from inference import generate_text

app = FastAPI()

# Initialize Accelerator (handles device placement and multi-GPU setups)
accelerator = Accelerator()

# Load model and tokenizer
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# Place the model on the available device(s); despite what the original comment
# said, accelerator.prepare is used here for inference/serving, not training
model = accelerator.prepare(model)


class GenerationRequest(BaseModel):
    messages: List[Dict[str, str]]
    max_new_tokens: int = 100
    do_sample: bool = True
    temperature: float = 0.7
    top_p: float = 0.9


# The same handler is exposed under several route aliases, including the
# OpenAI-style /v1/chat/completions path
@app.post("/text")
@app.post("/chat/completions")
@app.post("/v1/chat/completions")
@app.post("/generate")
async def generate(request: GenerationRequest):
    try:
        args = request.model_dump()
        response = generate_text(args, model, tokenizer, accelerator)
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
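
# For reference: inference.generate_text is imported above but its body is not
# part of this file. A minimal sketch of what it might look like, assuming it
# applies the tokenizer's chat template and decodes only the newly generated
# tokens (an illustration, not the actual implementation in inference.py):
#
#     def generate_text(args, model, tokenizer, accelerator):
#         input_ids = tokenizer.apply_chat_template(
#             args["messages"], add_generation_prompt=True, return_tensors="pt"
#         ).to(accelerator.device)
#         output_ids = model.generate(
#             input_ids,
#             max_new_tokens=args["max_new_tokens"],
#             do_sample=args["do_sample"],
#             temperature=args["temperature"],
#             top_p=args["top_p"],
#         )
#         new_tokens = output_ids[0][input_ids.shape[-1]:]
#         return tokenizer.decode(new_tokens, skip_special_tokens=True)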


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7099)
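
# Example client call (hypothetical; host and port match uvicorn.run above):
#
#   curl -X POST http://localhost:7099/generate \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}], "max_new_tokens": 64}'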