-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
134 lines (110 loc) · 3.75 KB
/
server.py
File metadata and controls
134 lines (110 loc) · 3.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from flask import Flask, request
from flask_cors import CORS
import os
import werkzeug
import ast
import csv
from pathlib import Path
from werkzeug.utils import secure_filename
from werkzeug.datastructures import FileStorage
from flask_restx import Resource, Api
from werkzeug.middleware.proxy_fix import ProxyFix
import silknow_image_retrieval as sir
# Re-expose ``cached_property`` at the top level of the ``werkzeug`` package.
# NOTE(review): presumably a compatibility shim for a dependency that still
# looks up ``werkzeug.cached_property``; it runs after the imports above, so
# confirm it is still required.
werkzeug.cached_property = werkzeug.utils.cached_property

# Model identifiers. Each one is used as a sub-directory name under
# ./output_files/models/ and ./output_files/trees/ below.
VISUAL_MODEL_NAME = "visual_image_retrieval_v2"
SEMANTIC_MODEL_NAME = "visual_and_semantic_retrieval_v2"

# The CNN models are loaded once, at import time, and kept in module globals
# so that request handlers can reuse them.
model_visual = sir.preload_cnn_model(
    "./output_files/models/" + VISUAL_MODEL_NAME + "/"
)
model_semantic = sir.preload_cnn_model(
    "./output_files/models/" + SEMANTIC_MODEL_NAME + "/"
)

# The matching k-d trees are preloaded the same way. Each value is later
# unpacked into five items (tree, labels_tree, data_dict_train,
# relevant_variables, label2class_list).
kd_tree_visual = sir.preload_kd_tree(
    "./output_files/trees/" + VISUAL_MODEL_NAME + "/"
)
kd_tree_semantic = sir.preload_kd_tree(
    "./output_files/trees/" + SEMANTIC_MODEL_NAME + "/"
)
def process(model_name, model, kd_tree):
    """Run the nearest-neighbour retrieval for one model and return object URIs.

    Args:
        model_name: Name of the model; selects the ``output_files/preds/``
            sub-directory used for the retrieval output.
        model: Preloaded CNN model handed to the retrieval call.
        kd_tree: Five-item sequence unpacked as (tree, labels_tree,
            data_dict_train, relevant_variables, label2class_list).

    Returns:
        A list of ``http://data.silknow.org/object/...`` URIs built from the
        last CSV row whose third column holds a non-empty value, or an empty
        list when no such row exists.
    """
    # Explicit unpacking keeps the "exactly five items" requirement.
    (
        search_tree,
        tree_labels,
        train_data,
        variables,
        label_classes,
    ) = kd_tree

    sir.get_kNN_from_preloaded_cnn_and_tree(
        search_tree,
        tree_labels,
        train_data,
        variables,
        label_classes,
        master_file_retrieval="master_file_retrieval.txt",
        master_dir_retrieval=r"./samples/",
        model=model,
        pred_gt_dir="./output_files/preds/" + model_name + "/",
        num_neighbors=20,
    )

    # NOTE(review): kNN_LUT.csv is presumably written by the call above into
    # pred_gt_dir -- confirm against the silknow_image_retrieval package.
    lut_path = "output_files/preds/" + model_name + "/kNN_LUT.csv"
    found = []
    with open(lut_path) as lut_file:
        records = csv.reader(lut_file, delimiter=",")
        next(records, None)  # skip the headers
        for record in records:
            # The third column holds a Python literal of URI fragments.
            fragments = ast.literal_eval(record[2])
            if not fragments:
                continue
            # Each non-empty row replaces the previous result.
            found = [
                "http://data.silknow.org/object/" + fragment
                for fragment in fragments
            ]
    return found
# Printed once at import time, after the models and trees above are loaded.
print("Starting web server...")
app = Flask(__name__)
# Trust one proxy hop for each of the X-Forwarded-For / -Proto / -Host / -Port
# headers, so the app sees the client-facing address and scheme.
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_port=1)
# flask-restx API wrapper; its documentation UI is served at /doc.
api = Api(app, doc="/doc")
# Enable CORS on the app with the flask-cors defaults.
cors = CORS(app)
# All routes declared on ``ns`` belong to the "api" namespace.
ns = api.namespace("api", description="Image Retrieval API")
# Liveness probe: GET on the namespace's /status route answers with "OK".
@ns.route("/status")
class status_route(Resource):
    # Comments are used instead of docstrings so the generated API
    # documentation stays exactly as it was.
    def get(self):
        return "OK"
# Request parser describing the multipart upload: one required file part
# named "file". It is attached to the /retrieve route via ``api.expect``.
upload_parser = api.parser()
upload_parser.add_argument(
    "file", location="files", type=FileStorage, required=True
)
@ns.route("/retrieve", methods=["POST"])
@api.expect(upload_parser)
class retrieve_route(Resource):
    # Accepts an uploaded image, registers it as the single entry of
    # samples/collection.txt and runs both retrieval models on it.

    @ns.doc("retrieve_route")
    def post(self):
        """Retrieve visually and semantically similar objects for an uploaded image."""
        # Save image
        upload = request.files["file"]
        # secure_filename() strips path separators and unsafe characters and
        # may legitimately return an empty string. Saving to "files/" would
        # then fail with an unhandled server error, so reject it as a client
        # error instead.
        filename = secure_filename(upload.filename or "")
        if not filename:
            ns.abort(400, "Uploaded file has no usable filename")
        filepath = os.path.join("files", filename)
        Path(os.path.dirname(filepath)).mkdir(parents=True, exist_ok=True)
        upload.save(filepath)

        # Add image to list of images to process. Opening with "w" truncates
        # any previous collection file, so no separate exists()/remove() step
        # (which could race with another request) is needed.
        # NOTE(review): the retrieval code presumably resolves the image path
        # relative to the samples directory, hence the ".." prefix -- confirm.
        collection_file = os.path.join("samples", "collection.txt")
        with open(collection_file, "w") as collection:
            collection.write("#image\t#Label\n")
            collection.write(os.path.join("..", filepath) + "\tImage\n")

        # Process image with the visual model, then with the semantic model.
        visual_uris = process(
            model_name=VISUAL_MODEL_NAME,
            model=model_visual,
            kd_tree=kd_tree_visual,
        )
        semantic_uris = process(
            model_name=SEMANTIC_MODEL_NAME,
            model=model_semantic,
            kd_tree=kd_tree_semantic,
        )
        return {"visualUris": visual_uris, "semanticUris": semantic_uris}
@api.errorhandler
def default_error_handler(error):
    """Turn an otherwise unhandled exception into a JSON error payload.

    Args:
        error: The exception raised while handling a request.

    Returns:
        A ``(body, status)`` tuple. ``body`` is ``{"error": str(error)}``.
        ``status`` is ``error.code`` when that is an integer in the HTTP
        status range (100-599), otherwise 500.
    """
    status = getattr(error, "code", None)
    # ``code`` on an arbitrary exception may be None, a string, a method or a
    # library-specific number; only pass it on when it is a real HTTP status.
    # bool is excluded explicitly because it is a subclass of int.
    is_http_status = (
        isinstance(status, int)
        and not isinstance(status, bool)
        and 100 <= status <= 599
    )
    if not is_http_status:
        status = 500
    return {"error": str(error)}, status
if __name__ == "__main__":
    # Direct execution starts Flask's development server. Debug mode stays on
    # by default, as before, but can now be switched off without editing the
    # file by setting FLASK_DEBUG to 0/false/no/off.
    # WARNING: debug mode enables the interactive Werkzeug debugger, which
    # allows arbitrary code execution; never expose it on a public interface.
    debug_setting = os.environ.get("FLASK_DEBUG")
    debug_enabled = (
        debug_setting is None
        or debug_setting.strip().lower() not in {"0", "false", "no", "off"}
    )
    app.run(debug=debug_enabled)