diff --git a/classifybot.py b/classifybot.py index b592059..faf9c71 100644 --- a/classifybot.py +++ b/classifybot.py @@ -1,134 +1,403 @@ -#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Created on Wed May 18 16:15:19 2022 +Created on Wed Jul 20 11:32:58 2022 -@author: jack +@author: jced0001 """ +import zulip +import torch +import numpy as np +import train_model as cnn +from train_model import ConvNet +from get_zulip_data import getZulipData +import imageio as iio +import wget import os -import time +class classifybot(object): +############################################################################### +# Constructor +############################################################################### + def __init__(self): + self.init() + + self.getWhitelist() # Load in whitelist file if there is one + self.initCommandDict() # Initialise dictionary containing all of classifybot's commands + +############################################################################### +# Initialisation +############################################################################### + def init(self): + with open('config.ini','r') as f: # Open the config file + config = f.read() # config.ini must contain all of the following: + self.zuliprc = config.split('zulip_rc_path=')[1].split('\n')[0] # Path to the zulip bot's rc file + self.scanbotAddress = config.split('scanbot_address=')[1].split('\n')[0] # scanbot's email address + self.scanbotStream = config.split('scanbot_stream=')[1].split('\n')[0] # stream to walk through + + self.notifications = True + self.bot_message = "" + + self.zulipClient = [] + if(self.zuliprc): + self.zulipClient = zulip.Client(config_file=self.zuliprc) + + def getWhitelist(self): + self.whitelist = [] + try: + with open('whitelist.txt', 'r') as f: + d = f.read() + self.whitelist = d.split('\n')[:-1] + except: + print('No whitelist... add users to create one') -import tensorflow as tf -from tf_custom_metric import macro_soft_f1 + def initCommandDict(self): + self.commands = {'list_commands' : self.listCommands, + 'help' : self._help, + 'add_user' : self.addUsers, + 'list_users' : lambda args: str(self.whitelist), + 'stop' : self.stop, + 'get_model' : self.getModel, + 'train_model' : self.trainModel, + 'load_model' : self.loadModel, + 'load_zulip_data' : self.loadZulipData, + 'predict' : self.predict + } + +############################################################################### +# Zulip +############################################################################### + def handle_message(self, message, bot_handler=None): + messageContent = message + self.bot_message = [] + self.bot_handler = bot_handler + if(bot_handler): + if message['sender_email'] not in self.whitelist and self.whitelist: + self.sendReply(message['sender_email']) + self.sendReply('access denied') + return + + self.bot_message = message + messageContent = message['content'] + + command = messageContent.split(' ')[0].lower() + args = messageContent.split(' ')[1:] + + if(not command in self.commands): + reply = "Invalid command. Run *list_commands* to see command list" + self.sendReply(reply) + return + + reply = self.commands[command](args) + + if(reply): self.sendReply(reply) + + def sendReply(self,reply,message=""): + """ + Send reply text. Currently only supports zulip and console. -import wget + Parameters + ---------- + reply : Reply string + message : Zulip: message params for the specific message to reply ro. + If not passed in, replies to the last message sent by user. -import pickle + Returns + ------- + message_id : Returns the message id of the sent message. (zulip only) -import numpy as np + """ + if(not reply): return # Can't send nothing + if(self.notifications): # Only send reply if notifications are turned on + if(self.bot_handler): # If our reply pathway is zulip + replyTo = message # If we're replying to a specific message + if(not replyTo): replyTo = self.bot_message # If we're just replying to the last message sent by user + self.bot_handler.send_reply(replyTo, reply) # Send the message + return + + print(reply) # Print reply to console + + def reactToMessage(self,reaction,message=""): + """ + Scanbot emoji reaction to message -import zulip + Parameters + ---------- + reaction : Emoji name (currently zulip only) + message : Specific zulip message to react to. If not passed in, reacts + to the last message sent by user. -class classifybot(object): + """ + if(not self.bot_handler): # If we're not using zulip + print("Scanbot reaction: " + reaction) # Send reaction to console + return + reactTo = message # If we're reacting to a specific zulip message + if(not reactTo): reactTo = self.bot_message # Otherwise react to the last user message + react_request = { + 'message_id': reactTo['id'], # Message ID to react to + 'emoji_name': reaction, # Emoji scanbot reacts with + } + self.zulipClient.add_reaction(react_request) # API call to react to the message + +############################################################################### +# Classifybot Commands +############################################################################### + def listCommands(self,args): + return "\n". join([c for c in self.commands]) - def __init__(self): - self.model = tf.keras.models.load_model('kf_model.model', custom_objects={'macro_soft_f1':macro_soft_f1}) - self.class_names = pickle.load(open('class_names.pkl', 'rb')) - self.client = zulip.Client(config_file='zuliprc') - self.threshold = .8 - self.tip_conditioning = False - with open('config.ini', 'r') as f: - config = f.read() - self.scanbot_address = config.split('scanbot_address=')[1].split('\n')[0] - self.scanbot_handle = config.split('scanbot_handle=')[1].split('\n')[0] - self.classifybot_address = config.split('classifybot_address=')[1].split('\n')[0] + def addUsers(self,user,_help=False): + arg_dict = {'' : ['', 0, "(string) Add user email to whitelist (one at a time)"]} - def handle_message(self, message, bot_handler=None): - self.bot_handler = bot_handler + if(_help): return arg_dict - if 'set_threshold' in message['content']: - self.threshold = float(message['content'].split(' ')[1]) - # bot_handler.send_reply(message, 'threshold is ' + str(self.threshold)) - react_request = { - 'message_id': message['id'], - 'emoji_name': '+1', - } - _ = self.client.add_reaction(react_request) - return + if(len(user) != 1): self.reactToMessage("cross_mark"); return + if(' ' in user[0]): self.reactToMessage("cross_mark"); return + try: + self.whitelist.append(user[0]) + with open('whitelist.txt', 'w') as f: + for w in self.whitelist: + f.write(w+'\n') + except Exception as e: + return str(e) - if 'condition_tip' in message['content']: - self.tip_conditioning = True - react_request = { - 'message_id': message['id'], - 'emoji_name': '+1', - } - _ = self.client.add_reaction(react_request) - return + self.reactToMessage("+1") + + def stop(self,args): + pass + + def trainModel(self,user_args,_help=False): + arg_dict = {'-name' : ['-default', lambda x: str(x), "(str) Name the model"], + '-target' : ['-default', lambda x: str(x), "(str) Target label (binary implementation for now)"], + '-pklpath' : ['-default', lambda x: str(x), "(str) Path to pickled data"], + '-augment' : ['-default', lambda x: int(x), "(int) Augment data (reflections)"], + '-load' : ['1', lambda x: int(x), "(int) Auto load model after training. 0=No, 1=Yes"]} + + if(_help): return arg_dict + + error,user_arg_dict = self.userArgs(arg_dict,user_args) + if(error): return error + "\nRun ```help train_model``` if you're unsure." + + with open('config.ini','r') as f: # Open the config file + config = f.read() # config.ini must contain all of the following: + default_args = {'-name' : config.split('run_name=')[1].split('\n')[0], # Name of the run. pkl saved as zulipData-runName-batch_x.pkl + '-target' : config.split('target_label=')[1].split('\n')[0], # The label with respect to which we try and classify. Binary for now + '-pklpath' : config.split('pkl_path=')[1].split('\n')[0], # Output path for pkl'd data + '-augment' : config.split('augment_data=')[1].split('\n')[0]} # Flag to augment data during training + + for key in user_arg_dict: + if(user_arg_dict[key][0] == "-default"): + user_arg_dict[key][0] = default_args[key] + + args = self.unpackArgs(user_arg_dict) + + self.reactToMessage("gym") + try: + modelPath = cnn.trainNewCNN(*args[0:4]) + self.reactToMessage("muscle") + if(user_arg_dict['-load'][0] == '0'): return + self.model = ConvNet(load_model=modelPath) + self.reactToMessage("computer") + except Exception as e: + return str(e) + + with open('config.ini','r+') as f: + config = str(f.read()) + oldModel = config.split('run_name=')[1].split('\n')[0] + oldTarget = config.split('target_label=')[1].split('\n')[0] + oldpkl = config.split('pkl_path=')[1].split('\n')[0] + oldaug = config.split('augment_data=')[1].split('\n')[0] - if 'stop' in message['content']: - self.tip_conditioning = False - react_request = { - 'message_id': message['id'], - 'emoji_name': '+1', - } - _ = self.client.add_reaction(react_request) - return + config = config.replace('run_name=' + oldModel, + 'run_name=' + user_arg_dict['-name'][0]) + config = config.replace('target_label=' + oldTarget, + 'target_label=' + user_arg_dict['-target'][0]) + config = config.replace('pkl_path=' + oldpkl, + 'pkl_path=' + user_arg_dict['-pklpath'][0]) + config = config.replace('augment_data=' + oldaug, + 'augment_data=' + user_arg_dict['-augment'][0]) + + with open('config.ini','w') as f: + f.write(config) + + def loadModel(self,user_args,_help=False): + arg_dict = {'-name': ['-default', lambda x: str(x), "(str) Name the model"]} + + if(_help): return arg_dict - if(bot_handler): - ## get the png - # print(message['content']) + error,user_arg_dict = self.userArgs(arg_dict,user_args) + if(error): return error + "\nRun ```help laod_model``` if you're unsure." + + with open('config.ini','r') as f: # Open the config file + config = f.read() # config.ini must contain all of the following: + default_args = {'-name': config.split('run_name=')[1].split('\n')[0]} # Name of the run. pkl saved as zulipData-runName-batch_x.pkl + + for key in user_arg_dict: + if(user_arg_dict[key][0] == "-default"): + user_arg_dict[key][0] = default_args[key] + + args = self.unpackArgs(user_arg_dict) + + try: + self.model = ConvNet(*args) + print(self.model) + except Exception as e: + return str(e) + + with open('config.ini','r+') as f: + config = str(f.read()) + oldModel = config.split('run_name=')[1].split('\n')[0] + config = config.replace('run_name=' + oldModel, + 'run_name=' + user_arg_dict['-name'][0]) + + with open('config.ini','w') as f: + f.write(config) + + self.reactToMessage("computer") + + def predict(self,user_args,_help=False): + try: + try: url = self.bot_message['content'].split('')[0].replace('&', '&') + except: url = self.bot_message['content'].split('(')[1].split(')')[0] + filename = wget.download(url=url) + im = iio.imread(filename) + os.remove(filename) + im = np.array(im/np.max(im),dtype=np.float32) # Normalise the data and force to be float32 + x = [] + x.append(np.transpose(im[:,:,:3], (2,0,1))) # Append the image to x + x = np.array(x) # Convert to numpy array + x = torch.tensor(x) # Convert to a pytorch compatible tensor + print("predicting") + prediction = self.model(x[0:1]) + self.reactToMessage(["cross_mark","bulls_eye"][np.argmax(prediction.detach())]) + except Exception as e: + self.sendReply(str(e)) + self.reactToMessage("no_entry") + + def loadZulipData(self,user_args,_help=False): + arg_dict = {'-name' : ['-default', lambda x: str(x), "(str) Name the model"], + '-msgid' : ['-default', lambda x: str(x), "(str) Last message ID"], + '-target' : ['-default', lambda x: str(x), "(str) Target label (binary implementation for now)"], + '-pklpath' : ['-default', lambda x: str(x), "(str) Path to pickled data"], + '-augment' : ['-default', lambda x: int(x), "(int) Augment data (reflections)"]} + + if(_help): return arg_dict + + error,user_arg_dict = self.userArgs(arg_dict,user_args) + if(error): return error + "\nRun ```help train_model``` if you're unsure." + + with open('config.ini','r') as f: # Open the config file + config = f.read() # config.ini must contain all of the following: + default_args = {'-name' : config.split('run_name=')[1].split('\n')[0], # Name of the run. pkl saved as zulipData-runName-batch_x.pkl + '-msgid' : config.split('last_message_id=')[1].split('\n')[0],# Message ID anchor (get messages from this message ID) + '-target' : config.split('target_label=')[1].split('\n')[0], # The label with respect to which we try and classify. Binary for now + '-pklpath' : config.split('pkl_path=')[1].split('\n')[0], # Output path for pkl'd data + '-augment' : config.split('augment_data=')[1].split('\n')[0]} # Flag to augment data during training + + for key in user_arg_dict: + if(user_arg_dict[key][0] == "-default"): + user_arg_dict[key][0] = default_args[key] + + config_bk = [] + with open('config.ini','r+') as f: + config = str(f.read()) + config_bk = config # Keep a backup of this config in case we error, then put back the old config + oldModel = config.split('run_name=')[1].split('\n')[0] + oldMsgID = config.split('last_message_id=')[1].split('\n')[0] + oldTarget = config.split('target_label=')[1].split('\n')[0] + oldpkl = config.split('pkl_path=')[1].split('\n')[0] + oldaug = config.split('augment_data=')[1].split('\n')[0] + + config = config.replace('run_name=' + oldModel, + 'run_name=' + user_arg_dict['-name'][0]) + config = config.replace('last_message_id=' + oldMsgID, + 'last_message_id=' + user_arg_dict['-msgid'][0]) + config = config.replace('target_label=' + oldTarget, + 'target_label=' + user_arg_dict['-target'][0]) + config = config.replace('pkl_path=' + oldpkl, + 'pkl_path=' + user_arg_dict['-pklpath'][0]) + config = config.replace('augment_data=' + oldaug, + 'augment_data=' + user_arg_dict['-augment'][0]) + + with open('config.ini','w') as f: + f.write(config) + + try: + self.reactToMessage("working_on_it") + getZulipData() + self.reactToMessage("computer") + except Exception as e: + with open('config.ini','w') as f: + f.write(config_bk) + self.reactToMessage("-1") + return str(e) + + def getModel(self,user_args,_help=False): + if(_help): return + + with open('config.ini','r+') as f: + config = str(f.read()) + model = config.split('run_name=')[1].split('\n')[0] + + return(model) +############################################################################### +# Utilities +############################################################################### + def userArgs(self,arg_dict,user_args): + error = "" + for arg in user_args: # Override the defaults if user inputs them + try: + key,value = arg.split('=') + except: + error = "Invalid argument" + break + if(not key in arg_dict): + error = "invalid argument: " + key # return error message + break try: - # url = message['content'].split('')[0].replace('&', '&') - url = message['content'].split('(')[1].split(')')[0] #.replace('&', '&') - file = wget.download(url) - - ## from https://www.tensorflow.org/tutorials/images/classification - img_height = 224 - img_width = 224 - - img = tf.keras.utils.load_img(file, target_size=(img_height, img_width)) - img_array = tf.keras.utils.img_to_array(img) - img_array = tf.expand_dims(img_array, 0) - - predictions = self.model.predict(img_array) - score = np.array(tf.nn.softmax(predictions[0])) - - ## send reactions for categories above 80%: - score /= np.amax(np.array(score)) - class_names = np.array(self.class_names) - emojis = class_names[score > self.threshold] - - for emoji in emojis: - react_request = { - 'message_id': message['id'], - 'emoji_name': emoji, - } - _ = self.client.add_reaction(react_request) + arg_dict[key][1](value) # Validate the value + except: + error = "Invalid value for arg " + key + "." # Error if the value doesn't match the required data type + break - os.remove(file) - except Exception as e: - bot_handler.send_reply(message, e) - - if self.tip_conditioning: - time.sleep(15) - result = self.client.get_raw_message(message['id']) - labels = [] - for reaction in result['message']['reactions']: - if reaction['user']['email'] != self.classifybot_address: - labels.append(reaction['emoji_name']) - ## if no human labelling, use classifybot labels - if len(labels) < 1: - for reaction in result['message']['reactions']: - labels.append(reaction['emoji_name']) - - tip_shape = False - - bad_emojis = ['-1', 'barber', 'bow_and_arrow', 'duel', 'poop', 'two', 'temperature'] - good_emojis = ['+1', 'fire', 'flame', 'knife', 'sparkling_heart', 'tada'] - - if any([label in bad_emojis for label in labels]): - tip_shape = True - - if all([label in good_emojis for label in labels]): - return - - if tip_shape: - reply = "@**" + self.scanbot_handle + "** tip_shape" - bot_handler.send_reply(message, reply) - return - - + arg_dict[key][0] = value + + return [error,arg_dict] + + def unpackArgs(self,arg_dict): + args = [] + for key,value in arg_dict.items(): + if(value[0] == "-default"): # If the value is -default... + args.append("-default") # leave it so the function can retrieve the value from nanonis + continue + args.append(value[1](value[0])) # Convert the string into data type + + return args + + def _help(self,args): + if(not len(args)): + helpStr = "Type ```help ``` for more info\n" + return helpStr + self.listCommands(args=[]) + + command = args[0] + if(not command in self.commands): + return "Run ```list_commands``` to see valid commands" + + try: + helpStr = "**" + command + "**\n" + arg_dict = self.commands[command](args,_help=True) + for key,value in arg_dict.items(): + if(key): + helpStr += "```" + helpStr += key + "```: " + helpStr += value[2] + ". " + if(value[0]): + helpStr += "Default: ```" + value[0].replace("-default","config file") + helpStr += "```" + helpStr += "\n" + except: + return "No help for this command" + + return helpStr + handler_class = classifybot \ No newline at end of file diff --git a/get_zulip_data.py b/get_zulip_data.py index f3031c2..a9af9a1 100644 --- a/get_zulip_data.py +++ b/get_zulip_data.py @@ -3,178 +3,178 @@ """ Created on Tue May 17 11:05:48 2022 -pull in image pngs on firebase indexed on zulip thread, and get associated emoji reactions - -@author: jack +config.ini must contain +run_name= # Name given to the run. Outut is saved as .pkl +zulip_rc_path= +scanbot_address= +scanbot_stream= +last_message_id= # This gets overwritten after the run. set to 0 to start from begining of time +label_dict= # In the form of emoji1:label1,emoji2:label2. useful for same emoji with multiple names (e.g. -1 and thumbs_down) + +This code scans beginning at . Selects all +messages sent from directed at @ + +@author: Julian Ceddia & Jack Hellerstedt """ -import os -# import time import zulip -import wget import pickle - - -# make database files -with open('config.ini', 'r') as f: - config = f.read() - image_data_path = config.split('image_database_directory=')[1].split('\n')[0] - zuliprc_name = config.split('zuliprc_name=')[1].split('\n')[0] - classifybot_name = config.split('classifybot_name=')[1].split('\n')[0] - scanbot_address = config.split('scanbot_address=')[1].split('\n')[0] - scanbot_stream = config.split('scanbot_stream=')[1].split('\n')[0] - -try: - os.mkdir(image_data_path) -except: - print('image_data dir already exists') - -if not 'batch_0' in os.listdir(image_data_path): - os.mkdir(os.path.join(image_data_path, 'batch_0')) - - - -## specify zuliprc file -zuliprc_path = os.path.join(os.getcwd(), zuliprc_name) -client = zulip.Client(config_file=zuliprc_path) - -## specify hard-coded classifybot name -# classifybot_name = 'classifybot' - - - - -## specify stream & scanbot email address to read messages: -request = {} -## define the narrow -request['narrow'] = [ - {"operator": "sender", "operand": scanbot_address}, - {"operator": "stream", "operand": scanbot_stream}, - # {"operator": "topic", "operand": "survey"}, - ] - -## get id of first unread message -request['anchor'] = 'first_unread' -request['num_before'] = 0 -request['num_after'] = 0 - -##go for oldest -# request['anchor'] = 'oldest' -# request['num_before'] = 0 -# request['num_after'] = 1 - -result = client.get_messages(request) -print('first_unread anchor') -print(result) - -first_unread_id = result['messages'][0]['id'] - -## get id of newest message -request['anchor'] = 'newest' -request['num_before'] = 1 -request['num_after'] = 0 - -result = client.get_messages(request) -print('newest anchor') -print(result) - -newest_message_id = result['messages'][0]['id'] - -to_mark_read = [] -keep_unread = [] -for message_id in range(first_unread_id, newest_message_id, 100): - request['anchor'] = message_id - request['num_before'] = 0 - request['num_after'] = 100 +import wget +import os +import imageio as iio + +def getZulipData(): + with open('config.ini','r') as f: + config = f.read() # config.ini must contain all of the following: + runName = config.split('run_name=')[1].split('\n')[0] # Name of the run. pkl saved as zulipData-runName-batch_x.pkl + zulipRcPath = config.split('zulip_rc_path=')[1].split('\n')[0] # Path to the zulip bot's rc file + scanbotAddress = config.split('scanbot_address=')[1].split('\n')[0] # scanbot's email address + scanbotStream = config.split('scanbot_stream=')[1].split('\n')[0] # stream to walk through + lastMsgID = int(config.split('last_message_id=')[1].split('\n')[0]) # Start searching from this message ID. Doesn't have to be an exact message ID match. 0 starts from begining. Autoupdated at the end of a run to pick up from where you left off + batchSize = int(config.split('batch_size=')[1].split('\n')[0]) # Number of images per pkl file + nbatch = int(config.split('nbatch=')[1].split('\n')[0]) # Number of batches to process -1 for entire stream. 1 batch per pkl + userList = config.split('user_list=')[1].split('\n')[0].split(',') # Users to read labels from + labelDictIni = config.split('label_dict=')[1].split('\n')[0] # Map emojis to labels... In the form of emoji1:label1,emoji2:label2. useful for same emoji with multiple names (e.g. -1 and thumbs_down) + pklPath = config.split('pkl_path=')[1].split('\n')[0] # Output path for pkl'd data - ## check which batch folder to put images in - folders = os.listdir(image_data_path) - max_batch_index = 0 - for name in folders: - if name.split('_')[0] == 'batch': - batch_index = int(name.split('_')[1]) - if batch_index > max_batch_index: - max_batch_index = batch_index + labelDict = {} # Convert labelDictini to python dictionary... + labelDictIni = labelDictIni.split(',') + for ld in labelDictIni: + key,value = ld.split(':') + labelDict[key] = value - batch_path = os.path.join(image_data_path, 'batch_' + str(max_batch_index)) + if(not pklPath.endswith('/')): pklPath += '/' + pklPath += runName + '/' + pklParams = {"runName": runName, # Store config params in the final pickle file + "zulipRcPath": zulipRcPath, + "scanbotAddress": scanbotAddress, + "scanbotStream": scanbotStream, + "lastMsgID": lastMsgID, + "userList": userList, + "batchSize": batchSize, + "nbatch": nbatch, + "labelDict": labelDict} - if len(os.listdir(batch_path)) > 256: - batch_path = os.path.join(image_data_path, 'batch_' + str(max_batch_index+1)) - os.mkdir(batch_path) - pickle.dump(True, open('retrain_flag.pkl', 'wb')) + client = zulip.Client(config_file=zulipRcPath) # Zulip Client + handle = client.get_profile()['full_name'] # Bot's handle + try: os.mkdir(pklPath) + except: pass + try: os.mkdir(pklPath + "labelled") + except: pass + try: os.mkdir(pklPath + "unlabelled") + except: pass try: - batch_labels = pickle.load(open(os.path.join(batch_path,'file_labels.pkl'), 'rb')) - except: - print('no batch labels ' + batch_path) - batch_labels = {} + labelledBatchNo = max([int(b.split('batch_')[1].split('.pkl')[0]) for b in # If there are any pkl's with this runName, check what batch number we're up to + [f for f in os.listdir(pklPath + 'labelled/') + if('zulipData-' + runName + '-labelled-' in f)]]) + labelledData = pickle.load(open(pklPath + 'labelled/zulipData-' + runName + '-labelled-batch_' # load in the batch file to continue adding to it + + str(labelledBatchNo) + '.pkl','rb')) + labelledData = labelledData['data'] # just get the data from it + if(len(labelledData) >= batchSize): # If this batch is full, start from the next one + labelledBatchNo += 1 + labelledData = {} + except: # If no batches with runName, start from 0 + labelledBatchNo = 0 + labelledData = {} - - results = client.get_messages(request) - if results['result'] == 'success': - for message in results['messages']: - if '
' in message['content'] and 'read' not in message['flags']: - url = message['content'].split('')[0].replace('&', '&') - labels = [] - for reaction in message['reactions']: - if reaction['user']['full_name'] != classifybot_name: - labels.append(reaction['emoji_name']) - - if len(labels) > 0 and '.png' in url: + try: + unlabelledBatchNo = max([int(b.split('batch_')[1].split('.pkl')[0]) for b in# If there are any pkl's with this runName, check what batch number we're up to + [f for f in os.listdir(pklPath + 'unlabelled/') + if('zulipData-' + runName + '-unlabelled-' in f)]]) + unlabelledData = pickle.load(open(pklPath + 'unlabelled/zulipData-' + runName + '-unlabelled-batch_' # load in the batch file to continue adding to it + + str(unlabelledBatchNo) + '.pkl','rb')) + unlabelledData = unlabelledData['data'] # just get the data from it + if(len(unlabelledData) >= batchSize): # If this batch is full, start from the next one + unlabelledBatchNo += 1 + unlabelledData = {} + except: # If no batches with runName, start from 0 + unlabelledBatchNo = 0 + unlabelledData = {} + + pbatch = 0 # Number of batches completed so far + result = {'found_newest': False} # Initialise result + while(not result['found_newest'] and (pbatch < nbatch or nbatch < 0)): # Keep going until we run out of messages in the stream, or until we hit our batch limit. If nbatch=-1 then process all messages + request = {} # Request to pull messages + request['narrow'] = [ # Filter search on... + {"operator": "sender", "operand": scanbotAddress}, # Sender being scanbot (scanbot's email address) + {"operator": "stream", "operand": scanbotStream}] # Stream being 'scanbot' + # {"operator": "topic", "operand": "survey"}, # Topic being 'survey' + request['anchor'] = lastMsgID # Start search from this message ID. + request['num_before'] = 0 # 0 messages before lastMsgID + request['num_after'] = 100 # Grab the 100 messages after lastMsgID + result = client.get_messages(request) # Perform search + + messages = result['messages'] # All messages returned from search + + for message in messages: + if(pbatch == nbatch): break # Stop processing if we've hit our batch limit + if '.sxm' in message['content']: # If there's an sxm filename in the message content + try: + sxmFile = message['content'].split('.sxm')[0].split('/')[-1] + ".sxm" + if(sxmFile in labelledData): continue # Don't pickup the same sxm file twice + if(sxmFile in unlabelledData): continue # Don't pickup the same sxm file twice + try: - if not url.split('/scanbot/')[1].split('?')[0] in os.listdir(batch_path): - filename = wget.download(url=url, out=batch_path) - keyname = str(os.path.join(batch_path.split('/')[-1], filename.split('/')[-1])) - batch_labels[keyname] = labels + url = message['content'].split('')[0].replace('&', '&') except: - pass - - ## mark the message as read - to_mark_read.append(message['id']) + url = message['content'].split('(')[1].split(')')[0] - else: - keep_unread.append(message['id']) + labels = [] + for reaction in message['reactions']: + if reaction['user']['email'] in userList: # Only look at labels from users in the list + label = reaction['emoji_name'] + if(label not in labelDict): continue # Only process labels listed in the config file + if(labelDict[label] in labels): continue # Only add the label if it (or an equivalent one) hasn't been added yet + labels.append(labelDict[label]) # Append the label to the list for this sxm - else: ## message doesn't have an image in it; mark read - to_mark_read.append(message['id']) - - else: - print(results) - - ## dump the batch_labels - pickle.dump(batch_labels, open(os.path.join(batch_path, 'file_labels.pkl'), 'wb')) + if(len(labels)): + filename = wget.download(url=url) + im = iio.imread(filename) + os.remove(filename) + labelledData[sxmFile] = [im,labels] # Only data with more than zero labels goes in this list + if(len(labelledData) == batchSize): + print("Batch " + str(pbatch) + " complete") + pklParams['data'] = labelledData + pklName = 'zulipData-' + runName + '-labelled-' + pklName += 'batch_' + str(labelledBatchNo) + '.pkl' + pickle.dump(pklParams, open(pklPath + 'labelled/' + pklName, 'wb')) # Pickle containing config settings and labelled data + labelledBatchNo += 1 + labelledData = {} + pbatch += 1 + # else: + # unlabelledData[sxmFile] = [im,labels] # Unlabelled data goes in this list + # if(len(unlabelledData) == batchSize): + # pklParams['data'] = unlabelledData + # pklName = 'zulipData-' + runName + '-unlabelled-' + # pklName += 'batch_' + str(unlabelledBatchNo) + '.pkl' + # pickle.dump(pklParams, open(pklPath + 'unlabelled/' + pklName, 'wb')) # Pickle containing config settings and unlabelled data + # unlabelledBatchNo += 1 + # unlabelledData = {} + # pbatch += 1 + + lastMsgID = message['id'] + 1 # Remember this number for next time, so we don't need to go through entire message history + except: + pass + + if(len(labelledData)): + pklParams['data'] = labelledData + pklName = 'zulipData-' + runName + '-labelled-' + pklName += 'batch_' + str(labelledBatchNo) + '.pkl' + pickle.dump(pklParams, open(pklPath + 'labelled/' + pklName, 'wb')) # Pickle containing config settings and labelled data + # if(len(unlabelledData)): + # pklParams['data'] = unlabelledData + # pklName = 'zulipData-' + runName + '-unlabelled-' + # pklName += 'batch_' + str(unlabelledBatchNo) + '.pkl' + # pickle.dump(pklParams, open(pklPath + 'unlabelled/' + pklName, 'wb')) # Pickle containing config settings and unlabelled data + # unlabelledData = {} + with open('config.ini','r+') as f: + config = str(f.read()) + oldMsgID = config.split('last_message_id=')[1].split('\n')[0] + config = config.replace('last_message_id=' + str(oldMsgID), + 'last_message_id=' + str(lastMsgID)) -## now mark all read -if len(to_mark_read) > 0: - mkrd_request = { - 'messages': to_mark_read, - 'op': 'add', - 'flag': 'read', - } - result = client.update_message_flags(mkrd_request) - print(result) - - # ## put a sleep in here to not hit the API rate limit - # print('the slow loop to add eyes to all mark-read files') - # for msg_id in to_mark_read: - # react_request = { - # 'message_id': msg_id, - # 'emoji_name': 'eyes', - # } - # result = client.add_reaction(react_request) - # time.sleep(.3) - -## mark messages that haven't been labelled unread -mkunread_request = { - 'messages': keep_unread, - 'op': 'remove', - 'flag': 'read', - } -result = client.update_message_flags(mkunread_request) -print(result) - - - - \ No newline at end of file + with open('config.ini','w') as f: + f.write(config) # Update config.ini with lastMsgID to pick up from where we left off \ No newline at end of file diff --git a/tf_custom_metric.py b/tf_custom_metric.py deleted file mode 100644 index ce31dad..0000000 --- a/tf_custom_metric.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Fri May 20 11:30:25 2022 - -custom metric to supercede @tf.function - -@author: jack -""" - -import tensorflow as tf - -@tf.function -def macro_soft_f1(y, y_hat): - """Compute the macro soft F1-score as a cost (average 1 - soft-F1 across all labels). - Use probability values instead of binary predictions. - - Args: - y (int32 Tensor): targets array of shape (BATCH_SIZE, N_LABELS) - y_hat (float32 Tensor): probability matrix from forward propagation of shape (BATCH_SIZE, N_LABELS) - - Returns: - cost (scalar Tensor): value of the cost function for the batch - """ - y = tf.cast(y, tf.float32) - y_hat = tf.cast(y_hat, tf.float32) - tp = tf.reduce_sum(y_hat * y, axis=0) - fp = tf.reduce_sum(y_hat * (1 - y), axis=0) - fn = tf.reduce_sum((1 - y_hat) * y, axis=0) - soft_f1 = 2*tp / (2*tp + fn + fp + 1e-16) - cost = 1 - soft_f1 # reduce 1 - soft-f1 in order to increase soft-f1 - macro_cost = tf.reduce_mean(cost) # average on all labels - return macro_cost \ No newline at end of file diff --git a/train_model.py b/train_model.py index 814fde7..2372a41 100755 --- a/train_model.py +++ b/train_model.py @@ -1,189 +1,179 @@ -#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Created on Wed May 18 13:30:07 2022 +Created on Mon Jul 18 14:03:49 2022 -https://www.tensorflow.org/tutorials/load_data/images - -@author: jack +@author: Maxwell West & Julian Ceddia """ -from datetime import datetime -import pickle -# import numpy as np import os -import sys -# import PIL -# import PIL.Image -import tensorflow as tf -from tensorflow.keras import layers -import tensorflow_hub as hub -# import tensorflow_datasets as tfds - -import zulip - - -# import matplotlib -# matplotlib.use('Agg') ## for plotting headless -# import matplotlib.pyplot as plt - -from sklearn.preprocessing import MultiLabelBinarizer -from sklearn.model_selection import train_test_split - -# import pathlib - -with open('config.ini', 'r') as f: - config = f.read() - data_dir = config.split('image_database_directory=')[1].split('\n')[0] - -if not pickle.load(open('retrain_flag.pkl', 'rb')): - print('retrain flag not set') - sys.exit() - -master_label_dict = {} -for root, dirs, files in os.walk(data_dir): - for name in files: - if name == 'file_labels.pkl': - master_label_dict.update(pickle.load(open(os.path.join(root, name), 'rb'))) - -## make ordered lists of the dict keys and values -data_files = [] -data_labels = [] -for key, value in master_label_dict.items(): - data_files.append(os.path.join(data_dir, key)) - data_labels.append(value) +import pickle +import torch, numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from collections import Counter + +class ConvNet(nn.Module): + def __init__(self, load_model=""): + """ + Simple convolutional neural network with two convolutional layers, one + hidden layer, and one binary output layer. + + Parameters + ---------- + load_model : Path to the model to be loaded + + """ + super(ConvNet, self).__init__() + """ + /Replace this code with best architecture + """ + self.conv1 = nn.Conv2d( 3, 5, 3) # 3 input channels, 5 output channels and a kernel size of 3 + self.conv2 = nn.Conv2d( 5, 5, 3) # 5 input channels, 5 output channels and a kernel size of 3 + + self.pool = nn.MaxPool2d(2, 2) # https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html : Applies a 2D max pooling over an input signal composed of several input planes. + self.fc1 = nn.Linear(14045, 128) # 14,045 just comes from the amount of neurons the convolutional layers happened to finish with. 128 output neurons in this hidden layer is completely arbitrary + self.fc3 = nn.Linear(128, 2) # Output layer with binary classification to begin with + + self.opt = optim.Adam(self.parameters(), lr=0.001) # lr is another thing which can be played with... https://pythonguides.com/adam-optimizer-pytorch/ -for image_index, image_labels in enumerate(data_labels): - ## turn thumbs_down into -1: - image_labels = ['-1' if item == 'thumbs_down' else item for item in image_labels] - ## turn barber into striped_pole - image_labels = ['striped_pole' if item == 'barber' else item for item in image_labels] - - data_labels[image_index] = image_labels + """ + Replace this code with best architecture/ + """ + + if load_model: + load_model = "saved_nets/" + load_model + ".pt" + ckpt = torch.load(load_model) -X_train, X_val, y_train, y_val = train_test_split(data_files, data_labels, test_size=0.2, random_state=44) - - -## from https://github.com/ashrefm/multi-label-soft-f1/blob/master/Multi-Label%20Image%20Classification%20in%20TensorFlow%202.0.ipynb -mlb = MultiLabelBinarizer() -mlb.fit(y_train) -N_LABELS = len(mlb.classes_) - -# # Loop over all labels and show them -# N_LABELS = len(mlb.classes_) -# for (i, label) in enumerate(mlb.classes_): -# print("{}. {}".format(i, label)) - -y_train_bin = mlb.transform(y_train) -y_val_bin = mlb.transform(y_val) - - -IMG_SIZE = 224 # Specify height and width of image to match the input format of the model -CHANNELS = 3 # Keep RGB color channels to match the input format of the model - -def parse_function(filename, label): - """Function that returns a tuple of normalized image array and labels array. - Args: - filename: string representing path to image - label: 0/1 one-dimensional array of size N_LABELS - """ - # Read an image from a file - image_string = tf.io.read_file(filename) - # Decode it into a dense vector - image_decoded = tf.image.decode_jpeg(image_string, channels=CHANNELS) - # Resize it to fixed shape - image_resized = tf.image.resize(image_decoded, [IMG_SIZE, IMG_SIZE]) - # Normalize it from [0, 255] to [0.0, 1.0] - image_normalized = image_resized / 255.0 - return image_normalized, label - -BATCH_SIZE = 256 # Big enough to measure an F1-score -AUTOTUNE = tf.data.experimental.AUTOTUNE # Adapt preprocessing and prefetching dynamically -SHUFFLE_BUFFER_SIZE = 25 # Shuffle the training data by a chunk of 1024 observations - -def create_dataset(filenames, labels, is_training=True): - """Load and parse dataset. - Args: - filenames: list of image paths - labels: numpy array of shape (BATCH_SIZE, N_LABELS) - is_training: boolean to indicate training mode - """ + if "state_dict" in ckpt.keys(): + self.load_state_dict(ckpt['state_dict']) + + else: + self.load_state_dict(ckpt) + + def forward(self, x): + """ + Function that's called when a prediction is to be made. call like: + net = ConvNet()... + ... + prediction = net(data) + + Parameters + ---------- + x : Data/image to predict on + + Returns + ------- + x : Prediction tensor where each element is a label. Take the highest + + """ + """ + /Replace this code with the best architecture + """ + x = self.pool(F.relu(self.conv1(x))) # The first convolutional layer + x = self.pool(F.relu(self.conv2(x))) # Second convolutional layer + + x = x.view(x.size(0), -1) # Returns a new tensor with the same data as the x-tensor but of a different shape. + x = F.relu(self.fc1(x)) # Hidden layer + x = self.fc3(x) # Output layer + """ + Replace this code with the best architecture/ + """ + return x + + def train(self, x_train, y_train, x_test, y_test, name, epochs=1, batch_size=64): + criterion = nn.CrossEntropyLoss() + best = 0.0 # Keep track of the best performing model + print('Start training...') + print('------------------------------------------------') + print(' Train Acc | Test Acc | Best Test Acc | Loss') + print('------------------------------------------------') + for epoch in range(epochs): + running_loss = 0.0 + for i in range(x_train.size(0) // batch_size): + inputs = x_train[i * batch_size : (i+1) * batch_size] + labels = y_train[i * batch_size : (i+1) * batch_size] + + self.opt.zero_grad() - # Create a first dataset of file paths and labels - dataset = tf.data.Dataset.from_tensor_slices((filenames, labels)) - # Parse and preprocess observations in parallel - dataset = dataset.map(parse_function, num_parallel_calls=AUTOTUNE) + outputs = self(inputs) + loss = criterion(outputs, labels) + loss.backward() + self.opt.step() - if is_training == True: - # This is a small dataset, only load it once, and keep it in memory. - dataset = dataset.cache() - # Shuffle the data each buffer size - dataset = dataset.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE) + running_loss += float(loss.detach()) + + if i and not(i % 10) or True: + + with torch.no_grad(): + + train_pred = self(x_train) + train_acc = (torch.sum(torch.argmax(train_pred, axis=1) == y_train) / y_train.size(0)).item() + + test_pred = self(x_test) + test_acc = (torch.sum(torch.argmax(test_pred, axis=1) == y_test) / y_test.size(0)).item() + + + if test_acc > best: + best = test_acc + torch.save(self.state_dict(), "saved_nets/" + name + ".pt") + + print(f' {train_acc:.3f} | {test_acc:.3f} | {best:.3f} | {running_loss:.3f}') + + running_loss = 0.0 + + print('Done training') + return name + +def trainNewCNN(runName, targetLabel, pklPath, augmentData): + x = [] # The images will go here + y = [] # The labels will go here + allLabels = [] # This will keep count of all the labels we see + if(not pklPath.endswith('/')): pklPath += '/' + pklPath += runName + "/labelled/" + pklFiles = os.listdir(pklPath) + for pklFile in pklFiles: + batchDict = pickle.load(open(pklPath + pklFile,'rb')) + batchData = batchDict['data'] + print(pklFile) + for key, value in batchData.items(): + im = np.array(value[0]/np.max(value[0]),dtype=np.float32) # Normalise the data and force to be float32 + labels = value[1] + if(im.shape != (221, 221, 4)): continue # Skip images that are the wrong size + x.append(np.transpose(im[:,:,:3], (2,0,1))) # Append the image to x + y.append(int(targetLabel in labels)) # At the moment I'm assuming the labels are just 0 or 1 depending on whether the target label is present + + if augmentData: # Optionally add images which are just reflections of existing images + x.append(np.transpose(im[:,::-1,:3], (2,0,1))) # Reflect in x + y.append(int(targetLabel in labels)) + + x.append(np.transpose(im[::-1,:,:3], (2,0,1))) # Reflect in y + y.append(int(targetLabel in labels)) + + x.append(np.transpose(im[::-1,::-1,:3], (2,0,1))) # Reflect in xy + y.append(int(targetLabel in labels)) + + allLabels.extend(labels * (4 if augmentData else 1)) # Keeping count of all labels we've seen - # Batch the data for multiple steps - dataset = dataset.batch(BATCH_SIZE) - # Fetch batches in the background while the model is training. - dataset = dataset.prefetch(buffer_size=AUTOTUNE) + x = np.array(x) # Convert to numpy array + y = np.array(y) # Convert to numpy array - return dataset - - -train_ds = create_dataset(X_train, y_train_bin) -val_ds = create_dataset(X_val, y_val_bin) - - -### headless model - -model = tf.keras.Sequential() - -model.add(hub.KerasLayer("https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/5", - trainable=False)) - -model.add(layers.Dense(256, activation='relu', name='hidden_layer_1')) -model.add(layers.Dense(256, activation='relu', name='hidden_layer_2')) -model.add(layers.Dense(N_LABELS, activation='sigmoid', name='output')) - -model.build([None, IMG_SIZE, IMG_SIZE, CHANNELS]) - -model.summary() - -from tf_custom_metric import macro_soft_f1 - -LR = 1e-5 # keep it small when transfer learning -EPOCHS = 30 - -model.compile( - optimizer=tf.keras.optimizers.Adam(learning_rate=LR), - loss=macro_soft_f1, - metrics=[macro_soft_f1]) - -start = datetime.now() -history = model.fit(train_ds, - epochs=EPOCHS, - validation_data=create_dataset(X_val, y_val_bin)) -print('\nTraining took {}'.format(datetime.now()-start)) - - - -model.save('kf_model.model') - -with open('class_names.pkl', 'wb') as f: - pickle.dump(mlb.classes_, f) + print('Total number of images: ', len(x)) + print('Summary of labels:') + label_freq = sorted(zip(Counter(allLabels).keys(), Counter(allLabels).values()), key=lambda x: -x[1]) + print(*label_freq,sep='\n') -print(mlb.classes_) - -## set training flag back to false -pickle.dump(False, open('retrain_flag.pkl', 'wb')) - -## zulip message saying the training is done -client = zulip.Client(config_file='zuliprc') - -request = { - "type": "stream", - "to": "scanbot", - "topic": "TF model", - "content": "train_model finished" - } -result = client.send_message(request) -print(result) + shuffle = np.random.permutation(range(x.shape[0])) # Shuffle the data + + x = torch.tensor(x[shuffle]) # Convert to a pytorch compatible tensor + y = torch.tensor(y[shuffle], dtype=torch.long) + + cutoff = int(0.9 * x.shape[0]) # Split into training and testing/validation subsets + x_train, x_test = x[:cutoff], x[cutoff:] # Image data + y_train, y_test = y[:cutoff], y[cutoff:] # Labels + + net = ConvNet() # Instantiate the CNN + modelPath = net.train(x_train, y_train, x_test, y_test, name=runName, epochs=2, batch_size=64) # Train the CNN + + return modelPath \ No newline at end of file