-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
48 lines (37 loc) · 1.02 KB
/
main.py
File metadata and controls
48 lines (37 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import json
import os
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
# Load service configuration (model name, URL prefix, endpoint path,
# hypothesis template) from the path given in CLASSIFY_API_CONF,
# defaulting to ./config.json in the working directory.
config_path = os.getenv("CLASSIFY_API_CONF", "config.json")
# Explicit encoding: JSON is UTF-8 by spec, and the platform default
# encoding (used when encoding= is omitted) is not guaranteed to be UTF-8.
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

# root_path lets the app be served behind a reverse proxy at a sub-path.
app = FastAPI(root_path=config["root"])
class Item(BaseModel):
    """Request body for the classification endpoint."""
    text: str  # the document to classify
    labels: list[str]  # candidate label names to score against the text
# Shared tokenizer, capped at a 512-token context so over-long inputs are
# truncated by the pipeline instead of overflowing the model's window.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=config["model"],
    model_max_length=512,
    use_fast=True,
)

# Fix: only request float16 when a GPU is present. The original passed
# torch_dtype=torch.float16 unconditionally, but many CPU kernels lack
# fp16 support, so a CPU-only host could fail or run extremely slowly.
# GPU behavior (device=0, fp16) is unchanged.
_use_cuda = torch.cuda.is_available()
classifier = pipeline(
    "zero-shot-classification",
    model=config["model"],
    device=0 if _use_cuda else -1,  # -1 = CPU in the transformers pipeline API
    batch_size=8,
    torch_dtype=torch.float16 if _use_cuda else torch.float32,
    tokenizer=tokenizer,
)
@app.post(config["endpoint"])
async def classify(item: Item):
    """Zero-shot classify the posted text against the caller's labels.

    Each candidate label is scored independently (multi_label=True) using
    the hypothesis template from the service config; the raw pipeline
    result (labels with their scores) is returned unmodified.
    """
    template = config["template"]
    prediction = classifier(
        sequences=item.text,
        candidate_labels=item.labels,
        hypothesis_template=template,
        multi_label=True,
        max_length=512,
    )
    return prediction