+
+@author: Julian Ceddia & Jack Hellerstedt
"""
-import os
-# import time
import zulip
-import wget
import pickle
-
-
-# make database files
-with open('config.ini', 'r') as f:
- config = f.read()
- image_data_path = config.split('image_database_directory=')[1].split('\n')[0]
- zuliprc_name = config.split('zuliprc_name=')[1].split('\n')[0]
- classifybot_name = config.split('classifybot_name=')[1].split('\n')[0]
- scanbot_address = config.split('scanbot_address=')[1].split('\n')[0]
- scanbot_stream = config.split('scanbot_stream=')[1].split('\n')[0]
-
-try:
- os.mkdir(image_data_path)
-except:
- print('image_data dir already exists')
-
-if not 'batch_0' in os.listdir(image_data_path):
- os.mkdir(os.path.join(image_data_path, 'batch_0'))
-
-
-
-## specify zuliprc file
-zuliprc_path = os.path.join(os.getcwd(), zuliprc_name)
-client = zulip.Client(config_file=zuliprc_path)
-
-## specify hard-coded classifybot name
-# classifybot_name = 'classifybot'
-
-
-
-
-## specify stream & scanbot email address to read messages:
-request = {}
-## define the narrow
-request['narrow'] = [
- {"operator": "sender", "operand": scanbot_address},
- {"operator": "stream", "operand": scanbot_stream},
- # {"operator": "topic", "operand": "survey"},
- ]
-
-## get id of first unread message
-request['anchor'] = 'first_unread'
-request['num_before'] = 0
-request['num_after'] = 0
-
-##go for oldest
-# request['anchor'] = 'oldest'
-# request['num_before'] = 0
-# request['num_after'] = 1
-
-result = client.get_messages(request)
-print('first_unread anchor')
-print(result)
-
-first_unread_id = result['messages'][0]['id']
-
-## get id of newest message
-request['anchor'] = 'newest'
-request['num_before'] = 1
-request['num_after'] = 0
-
-result = client.get_messages(request)
-print('newest anchor')
-print(result)
-
-newest_message_id = result['messages'][0]['id']
-
-to_mark_read = []
-keep_unread = []
-for message_id in range(first_unread_id, newest_message_id, 100):
- request['anchor'] = message_id
- request['num_before'] = 0
- request['num_after'] = 100
+import wget
+import os
+import imageio as iio
+
+def getZulipData():
+ with open('config.ini','r') as f:
+ config = f.read() # config.ini must contain all of the following:
+ runName = config.split('run_name=')[1].split('\n')[0] # Name of the run. pkl saved as zulipData-runName-batch_x.pkl
+ zulipRcPath = config.split('zulip_rc_path=')[1].split('\n')[0] # Path to the zulip bot's rc file
+ scanbotAddress = config.split('scanbot_address=')[1].split('\n')[0] # scanbot's email address
+ scanbotStream = config.split('scanbot_stream=')[1].split('\n')[0] # stream to walk through
+ lastMsgID = int(config.split('last_message_id=')[1].split('\n')[0]) # Start searching from this message ID. Doesn't have to be an exact message ID match. 0 starts from the beginning. Auto-updated at the end of a run to pick up from where you left off
+ batchSize = int(config.split('batch_size=')[1].split('\n')[0]) # Number of images per pkl file
+ nbatch = int(config.split('nbatch=')[1].split('\n')[0]) # Number of batches to process -1 for entire stream. 1 batch per pkl
+ userList = config.split('user_list=')[1].split('\n')[0].split(',') # Users to read labels from
+ labelDictIni = config.split('label_dict=')[1].split('\n')[0] # Map emojis to labels... In the form of emoji1:label1,emoji2:label2. useful for same emoji with multiple names (e.g. -1 and thumbs_down)
+ pklPath = config.split('pkl_path=')[1].split('\n')[0] # Output path for pkl'd data
- ## check which batch folder to put images in
- folders = os.listdir(image_data_path)
- max_batch_index = 0
- for name in folders:
- if name.split('_')[0] == 'batch':
- batch_index = int(name.split('_')[1])
- if batch_index > max_batch_index:
- max_batch_index = batch_index
+ labelDict = {} # Convert labelDictini to python dictionary...
+ labelDictIni = labelDictIni.split(',')
+ for ld in labelDictIni:
+ key,value = ld.split(':')
+ labelDict[key] = value
- batch_path = os.path.join(image_data_path, 'batch_' + str(max_batch_index))
+ if(not pklPath.endswith('/')): pklPath += '/'
+ pklPath += runName + '/'
+ pklParams = {"runName": runName, # Store config params in the final pickle file
+ "zulipRcPath": zulipRcPath,
+ "scanbotAddress": scanbotAddress,
+ "scanbotStream": scanbotStream,
+ "lastMsgID": lastMsgID,
+ "userList": userList,
+ "batchSize": batchSize,
+ "nbatch": nbatch,
+ "labelDict": labelDict}
- if len(os.listdir(batch_path)) > 256:
- batch_path = os.path.join(image_data_path, 'batch_' + str(max_batch_index+1))
- os.mkdir(batch_path)
- pickle.dump(True, open('retrain_flag.pkl', 'wb'))
+ client = zulip.Client(config_file=zulipRcPath) # Zulip Client
+ handle = client.get_profile()['full_name'] # Bot's handle
+ try: os.mkdir(pklPath)
+ except: pass
+ try: os.mkdir(pklPath + "labelled")
+ except: pass
+ try: os.mkdir(pklPath + "unlabelled")
+ except: pass
try:
- batch_labels = pickle.load(open(os.path.join(batch_path,'file_labels.pkl'), 'rb'))
- except:
- print('no batch labels ' + batch_path)
- batch_labels = {}
+ labelledBatchNo = max([int(b.split('batch_')[1].split('.pkl')[0]) for b in # If there are any pkl's with this runName, check what batch number we're up to
+ [f for f in os.listdir(pklPath + 'labelled/')
+ if('zulipData-' + runName + '-labelled-' in f)]])
+ labelledData = pickle.load(open(pklPath + 'labelled/zulipData-' + runName + '-labelled-batch_' # load in the batch file to continue adding to it
+ + str(labelledBatchNo) + '.pkl','rb'))
+ labelledData = labelledData['data'] # just get the data from it
+ if(len(labelledData) >= batchSize): # If this batch is full, start from the next one
+ labelledBatchNo += 1
+ labelledData = {}
+ except: # If no batches with runName, start from 0
+ labelledBatchNo = 0
+ labelledData = {}
-
- results = client.get_messages(request)
- if results['result'] == 'success':
- for message in results['messages']:
- if '' in message['content'] and 'read' not in message['flags']:
- url = message['content'].split('
')[0].replace('&', '&')
- labels = []
- for reaction in message['reactions']:
- if reaction['user']['full_name'] != classifybot_name:
- labels.append(reaction['emoji_name'])
-
- if len(labels) > 0 and '.png' in url:
+ try:
+ unlabelledBatchNo = max([int(b.split('batch_')[1].split('.pkl')[0]) for b in# If there are any pkl's with this runName, check what batch number we're up to
+ [f for f in os.listdir(pklPath + 'unlabelled/')
+ if('zulipData-' + runName + '-unlabelled-' in f)]])
+ unlabelledData = pickle.load(open(pklPath + 'unlabelled/zulipData-' + runName + '-unlabelled-batch_' # load in the batch file to continue adding to it
+ + str(unlabelledBatchNo) + '.pkl','rb'))
+ unlabelledData = unlabelledData['data'] # just get the data from it
+ if(len(unlabelledData) >= batchSize): # If this batch is full, start from the next one
+ unlabelledBatchNo += 1
+ unlabelledData = {}
+ except: # If no batches with runName, start from 0
+ unlabelledBatchNo = 0
+ unlabelledData = {}
+
+ pbatch = 0 # Number of batches completed so far
+ result = {'found_newest': False} # Initialise result
+ while(not result['found_newest'] and (pbatch < nbatch or nbatch < 0)): # Keep going until we run out of messages in the stream, or until we hit our batch limit. If nbatch=-1 then process all messages
+ request = {} # Request to pull messages
+ request['narrow'] = [ # Filter search on...
+ {"operator": "sender", "operand": scanbotAddress}, # Sender being scanbot (scanbot's email address)
+ {"operator": "stream", "operand": scanbotStream}] # Stream being 'scanbot'
+ # {"operator": "topic", "operand": "survey"}, # Topic being 'survey'
+ request['anchor'] = lastMsgID # Start search from this message ID.
+ request['num_before'] = 0 # 0 messages before lastMsgID
+ request['num_after'] = 100 # Grab the 100 messages after lastMsgID
+ result = client.get_messages(request) # Perform search
+
+ messages = result['messages'] # All messages returned from search
+
+ for message in messages:
+ if(pbatch == nbatch): break # Stop processing if we've hit our batch limit
+ if '.sxm' in message['content']: # If there's an sxm filename in the message content
+ try:
+ sxmFile = message['content'].split('.sxm')[0].split('/')[-1] + ".sxm"
+ if(sxmFile in labelledData): continue # Don't pick up the same sxm file twice
+ if(sxmFile in unlabelledData): continue # Don't pick up the same sxm file twice
+
try:
- if not url.split('/scanbot/')[1].split('?')[0] in os.listdir(batch_path):
- filename = wget.download(url=url, out=batch_path)
- keyname = str(os.path.join(batch_path.split('/')[-1], filename.split('/')[-1]))
- batch_labels[keyname] = labels
+ url = message['content'].split('')[0].replace('&', '&')
except:
- pass
-
- ## mark the message as read
- to_mark_read.append(message['id'])
+ url = message['content'].split('(')[1].split(')')[0]
- else:
- keep_unread.append(message['id'])
+ labels = []
+ for reaction in message['reactions']:
+ if reaction['user']['email'] in userList: # Only look at labels from users in the list
+ label = reaction['emoji_name']
+ if(label not in labelDict): continue # Only process labels listed in the config file
+ if(labelDict[label] in labels): continue # Only add the label if it (or an equivalent one) hasn't been added yet
+ labels.append(labelDict[label]) # Append the label to the list for this sxm
- else: ## message doesn't have an image in it; mark read
- to_mark_read.append(message['id'])
-
- else:
- print(results)
-
- ## dump the batch_labels
- pickle.dump(batch_labels, open(os.path.join(batch_path, 'file_labels.pkl'), 'wb'))
+ if(len(labels)):
+ filename = wget.download(url=url)
+ im = iio.imread(filename)
+ os.remove(filename)
+ labelledData[sxmFile] = [im,labels] # Only data with more than zero labels goes in this list
+ if(len(labelledData) == batchSize):
+ print("Batch " + str(pbatch) + " complete")
+ pklParams['data'] = labelledData
+ pklName = 'zulipData-' + runName + '-labelled-'
+ pklName += 'batch_' + str(labelledBatchNo) + '.pkl'
+ pickle.dump(pklParams, open(pklPath + 'labelled/' + pklName, 'wb')) # Pickle containing config settings and labelled data
+ labelledBatchNo += 1
+ labelledData = {}
+ pbatch += 1
+ # else:
+ # unlabelledData[sxmFile] = [im,labels] # Unlabelled data goes in this list
+ # if(len(unlabelledData) == batchSize):
+ # pklParams['data'] = unlabelledData
+ # pklName = 'zulipData-' + runName + '-unlabelled-'
+ # pklName += 'batch_' + str(unlabelledBatchNo) + '.pkl'
+ # pickle.dump(pklParams, open(pklPath + 'unlabelled/' + pklName, 'wb')) # Pickle containing config settings and unlabelled data
+ # unlabelledBatchNo += 1
+ # unlabelledData = {}
+ # pbatch += 1
+
+ lastMsgID = message['id'] + 1 # Remember this number for next time, so we don't need to go through entire message history
+ except:
+ pass
+
+ if(len(labelledData)):
+ pklParams['data'] = labelledData
+ pklName = 'zulipData-' + runName + '-labelled-'
+ pklName += 'batch_' + str(labelledBatchNo) + '.pkl'
+ pickle.dump(pklParams, open(pklPath + 'labelled/' + pklName, 'wb')) # Pickle containing config settings and labelled data
+ # if(len(unlabelledData)):
+ # pklParams['data'] = unlabelledData
+ # pklName = 'zulipData-' + runName + '-unlabelled-'
+ # pklName += 'batch_' + str(unlabelledBatchNo) + '.pkl'
+ # pickle.dump(pklParams, open(pklPath + 'unlabelled/' + pklName, 'wb')) # Pickle containing config settings and unlabelled data
+ # unlabelledData = {}
+ with open('config.ini','r+') as f:
+ config = str(f.read())
+ oldMsgID = config.split('last_message_id=')[1].split('\n')[0]
+ config = config.replace('last_message_id=' + str(oldMsgID),
+ 'last_message_id=' + str(lastMsgID))
-## now mark all read
-if len(to_mark_read) > 0:
- mkrd_request = {
- 'messages': to_mark_read,
- 'op': 'add',
- 'flag': 'read',
- }
- result = client.update_message_flags(mkrd_request)
- print(result)
-
- # ## put a sleep in here to not hit the API rate limit
- # print('the slow loop to add eyes to all mark-read files')
- # for msg_id in to_mark_read:
- # react_request = {
- # 'message_id': msg_id,
- # 'emoji_name': 'eyes',
- # }
- # result = client.add_reaction(react_request)
- # time.sleep(.3)
-
-## mark messages that haven't been labelled unread
-mkunread_request = {
- 'messages': keep_unread,
- 'op': 'remove',
- 'flag': 'read',
- }
-result = client.update_message_flags(mkunread_request)
-print(result)
-
-
-
-
\ No newline at end of file
+ with open('config.ini','w') as f:
+ f.write(config) # Update config.ini with lastMsgID to pick up from where we left off
\ No newline at end of file
diff --git a/tf_custom_metric.py b/tf_custom_metric.py
deleted file mode 100644
index ce31dad..0000000
--- a/tf_custom_metric.py
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Created on Fri May 20 11:30:25 2022
-
-custom metric to supercede @tf.function
-
-@author: jack
-"""
-
-import tensorflow as tf
-
-@tf.function
-def macro_soft_f1(y, y_hat):
- """Compute the macro soft F1-score as a cost (average 1 - soft-F1 across all labels).
- Use probability values instead of binary predictions.
-
- Args:
- y (int32 Tensor): targets array of shape (BATCH_SIZE, N_LABELS)
- y_hat (float32 Tensor): probability matrix from forward propagation of shape (BATCH_SIZE, N_LABELS)
-
- Returns:
- cost (scalar Tensor): value of the cost function for the batch
- """
- y = tf.cast(y, tf.float32)
- y_hat = tf.cast(y_hat, tf.float32)
- tp = tf.reduce_sum(y_hat * y, axis=0)
- fp = tf.reduce_sum(y_hat * (1 - y), axis=0)
- fn = tf.reduce_sum((1 - y_hat) * y, axis=0)
- soft_f1 = 2*tp / (2*tp + fn + fp + 1e-16)
- cost = 1 - soft_f1 # reduce 1 - soft-f1 in order to increase soft-f1
- macro_cost = tf.reduce_mean(cost) # average on all labels
- return macro_cost
\ No newline at end of file
diff --git a/train_model.py b/train_model.py
index 814fde7..2372a41 100755
--- a/train_model.py
+++ b/train_model.py
@@ -1,189 +1,179 @@
-#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
-Created on Wed May 18 13:30:07 2022
+Created on Mon Jul 18 14:03:49 2022
-https://www.tensorflow.org/tutorials/load_data/images
-
-@author: jack
+@author: Maxwell West & Julian Ceddia
"""
-from datetime import datetime
-import pickle
-# import numpy as np
import os
-import sys
-# import PIL
-# import PIL.Image
-import tensorflow as tf
-from tensorflow.keras import layers
-import tensorflow_hub as hub
-# import tensorflow_datasets as tfds
-
-import zulip
-
-
-# import matplotlib
-# matplotlib.use('Agg') ## for plotting headless
-# import matplotlib.pyplot as plt
-
-from sklearn.preprocessing import MultiLabelBinarizer
-from sklearn.model_selection import train_test_split
-
-# import pathlib
-
-with open('config.ini', 'r') as f:
- config = f.read()
- data_dir = config.split('image_database_directory=')[1].split('\n')[0]
-
-if not pickle.load(open('retrain_flag.pkl', 'rb')):
- print('retrain flag not set')
- sys.exit()
-
-master_label_dict = {}
-for root, dirs, files in os.walk(data_dir):
- for name in files:
- if name == 'file_labels.pkl':
- master_label_dict.update(pickle.load(open(os.path.join(root, name), 'rb')))
-
-## make ordered lists of the dict keys and values
-data_files = []
-data_labels = []
-for key, value in master_label_dict.items():
- data_files.append(os.path.join(data_dir, key))
- data_labels.append(value)
+import pickle
+import torch, numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from collections import Counter
+
+class ConvNet(nn.Module):
+ def __init__(self, load_model=""):
+ """
+ Simple convolutional neural network with two convolutional layers, one
+ hidden layer, and one binary output layer.
+
+ Parameters
+ ----------
+ load_model : Path to the model to be loaded
+
+ """
+ super(ConvNet, self).__init__()
+ """
+ /Replace this code with best architecture
+ """
+ self.conv1 = nn.Conv2d( 3, 5, 3) # 3 input channels, 5 output channels and a kernel size of 3
+ self.conv2 = nn.Conv2d( 5, 5, 3) # 5 input channels, 5 output channels and a kernel size of 3
+
+ self.pool = nn.MaxPool2d(2, 2) # https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html : Applies a 2D max pooling over an input signal composed of several input planes.
+ self.fc1 = nn.Linear(14045, 128) # 14,045 = 5 channels x 53 x 53, the flattened size of the convolutional output for a 221x221 input. The 128 output neurons in this hidden layer are an arbitrary choice
+ self.fc3 = nn.Linear(128, 2) # Output layer with binary classification to begin with
+
+ self.opt = optim.Adam(self.parameters(), lr=0.001) # lr is another thing which can be played with... https://pythonguides.com/adam-optimizer-pytorch/
-for image_index, image_labels in enumerate(data_labels):
- ## turn thumbs_down into -1:
- image_labels = ['-1' if item == 'thumbs_down' else item for item in image_labels]
- ## turn barber into striped_pole
- image_labels = ['striped_pole' if item == 'barber' else item for item in image_labels]
-
- data_labels[image_index] = image_labels
+ """
+ Replace this code with best architecture/
+ """
+
+ if load_model:
+ load_model = "saved_nets/" + load_model + ".pt"
+ ckpt = torch.load(load_model)
-X_train, X_val, y_train, y_val = train_test_split(data_files, data_labels, test_size=0.2, random_state=44)
-
-
-## from https://github.com/ashrefm/multi-label-soft-f1/blob/master/Multi-Label%20Image%20Classification%20in%20TensorFlow%202.0.ipynb
-mlb = MultiLabelBinarizer()
-mlb.fit(y_train)
-N_LABELS = len(mlb.classes_)
-
-# # Loop over all labels and show them
-# N_LABELS = len(mlb.classes_)
-# for (i, label) in enumerate(mlb.classes_):
-# print("{}. {}".format(i, label))
-
-y_train_bin = mlb.transform(y_train)
-y_val_bin = mlb.transform(y_val)
-
-
-IMG_SIZE = 224 # Specify height and width of image to match the input format of the model
-CHANNELS = 3 # Keep RGB color channels to match the input format of the model
-
-def parse_function(filename, label):
- """Function that returns a tuple of normalized image array and labels array.
- Args:
- filename: string representing path to image
- label: 0/1 one-dimensional array of size N_LABELS
- """
- # Read an image from a file
- image_string = tf.io.read_file(filename)
- # Decode it into a dense vector
- image_decoded = tf.image.decode_jpeg(image_string, channels=CHANNELS)
- # Resize it to fixed shape
- image_resized = tf.image.resize(image_decoded, [IMG_SIZE, IMG_SIZE])
- # Normalize it from [0, 255] to [0.0, 1.0]
- image_normalized = image_resized / 255.0
- return image_normalized, label
-
-BATCH_SIZE = 256 # Big enough to measure an F1-score
-AUTOTUNE = tf.data.experimental.AUTOTUNE # Adapt preprocessing and prefetching dynamically
-SHUFFLE_BUFFER_SIZE = 25 # Shuffle the training data by a chunk of 1024 observations
-
-def create_dataset(filenames, labels, is_training=True):
- """Load and parse dataset.
- Args:
- filenames: list of image paths
- labels: numpy array of shape (BATCH_SIZE, N_LABELS)
- is_training: boolean to indicate training mode
- """
+ if "state_dict" in ckpt.keys():
+ self.load_state_dict(ckpt['state_dict'])
+
+ else:
+ self.load_state_dict(ckpt)
+
+ def forward(self, x):
+ """
+ Function that's called when a prediction is to be made. call like:
+ net = ConvNet()...
+ ...
+ prediction = net(data)
+
+ Parameters
+ ----------
+ x : Data/image to predict on
+
+ Returns
+ -------
+ x : Prediction tensor where each element is the score for a label. Take the highest-scoring element (argmax) as the predicted label
+
+ """
+ """
+ /Replace this code with the best architecture
+ """
+ x = self.pool(F.relu(self.conv1(x))) # The first convolutional layer
+ x = self.pool(F.relu(self.conv2(x))) # Second convolutional layer
+
+ x = x.view(x.size(0), -1) # Returns a new tensor with the same data as the x-tensor but of a different shape.
+ x = F.relu(self.fc1(x)) # Hidden layer
+ x = self.fc3(x) # Output layer
+ """
+ Replace this code with the best architecture/
+ """
+ return x
+
+ def train(self, x_train, y_train, x_test, y_test, name, epochs=1, batch_size=64):
+ criterion = nn.CrossEntropyLoss()
+ best = 0.0 # Keep track of the best performing model
+ print('Start training...')
+ print('------------------------------------------------')
+ print(' Train Acc | Test Acc | Best Test Acc | Loss')
+ print('------------------------------------------------')
+ for epoch in range(epochs):
+ running_loss = 0.0
+ for i in range(x_train.size(0) // batch_size):
+ inputs = x_train[i * batch_size : (i+1) * batch_size]
+ labels = y_train[i * batch_size : (i+1) * batch_size]
+
+ self.opt.zero_grad()
- # Create a first dataset of file paths and labels
- dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
- # Parse and preprocess observations in parallel
- dataset = dataset.map(parse_function, num_parallel_calls=AUTOTUNE)
+ outputs = self(inputs)
+ loss = criterion(outputs, labels)
+ loss.backward()
+ self.opt.step()
- if is_training == True:
- # This is a small dataset, only load it once, and keep it in memory.
- dataset = dataset.cache()
- # Shuffle the data each buffer size
- dataset = dataset.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE)
+ running_loss += float(loss.detach())
+
+ if i and not(i % 10) or True:
+
+ with torch.no_grad():
+
+ train_pred = self(x_train)
+ train_acc = (torch.sum(torch.argmax(train_pred, axis=1) == y_train) / y_train.size(0)).item()
+
+ test_pred = self(x_test)
+ test_acc = (torch.sum(torch.argmax(test_pred, axis=1) == y_test) / y_test.size(0)).item()
+
+
+ if test_acc > best:
+ best = test_acc
+ torch.save(self.state_dict(), "saved_nets/" + name + ".pt")
+
+ print(f' {train_acc:.3f} | {test_acc:.3f} | {best:.3f} | {running_loss:.3f}')
+
+ running_loss = 0.0
+
+ print('Done training')
+ return name
+
+def trainNewCNN(runName, targetLabel, pklPath, augmentData):
+ x = [] # The images will go here
+ y = [] # The labels will go here
+ allLabels = [] # This will keep count of all the labels we see
+ if(not pklPath.endswith('/')): pklPath += '/'
+ pklPath += runName + "/labelled/"
+ pklFiles = os.listdir(pklPath)
+ for pklFile in pklFiles:
+ batchDict = pickle.load(open(pklPath + pklFile,'rb'))
+ batchData = batchDict['data']
+ print(pklFile)
+ for key, value in batchData.items():
+ im = np.array(value[0]/np.max(value[0]),dtype=np.float32) # Normalise the data and force to be float32
+ labels = value[1]
+ if(im.shape != (221, 221, 4)): continue # Skip images that are the wrong size
+ x.append(np.transpose(im[:,:,:3], (2,0,1))) # Append the image to x
+ y.append(int(targetLabel in labels)) # At the moment I'm assuming the labels are just 0 or 1 depending on whether the target label is present
+
+ if augmentData: # Optionally add images which are just reflections of existing images
+ x.append(np.transpose(im[:,::-1,:3], (2,0,1))) # Reflect in x
+ y.append(int(targetLabel in labels))
+
+ x.append(np.transpose(im[::-1,:,:3], (2,0,1))) # Reflect in y
+ y.append(int(targetLabel in labels))
+
+ x.append(np.transpose(im[::-1,::-1,:3], (2,0,1))) # Reflect in xy
+ y.append(int(targetLabel in labels))
+
+ allLabels.extend(labels * (4 if augmentData else 1)) # Keeping count of all labels we've seen
- # Batch the data for multiple steps
- dataset = dataset.batch(BATCH_SIZE)
- # Fetch batches in the background while the model is training.
- dataset = dataset.prefetch(buffer_size=AUTOTUNE)
+ x = np.array(x) # Convert to numpy array
+ y = np.array(y) # Convert to numpy array
- return dataset
-
-
-train_ds = create_dataset(X_train, y_train_bin)
-val_ds = create_dataset(X_val, y_val_bin)
-
-
-### headless model
-
-model = tf.keras.Sequential()
-
-model.add(hub.KerasLayer("https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/5",
- trainable=False))
-
-model.add(layers.Dense(256, activation='relu', name='hidden_layer_1'))
-model.add(layers.Dense(256, activation='relu', name='hidden_layer_2'))
-model.add(layers.Dense(N_LABELS, activation='sigmoid', name='output'))
-
-model.build([None, IMG_SIZE, IMG_SIZE, CHANNELS])
-
-model.summary()
-
-from tf_custom_metric import macro_soft_f1
-
-LR = 1e-5 # keep it small when transfer learning
-EPOCHS = 30
-
-model.compile(
- optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
- loss=macro_soft_f1,
- metrics=[macro_soft_f1])
-
-start = datetime.now()
-history = model.fit(train_ds,
- epochs=EPOCHS,
- validation_data=create_dataset(X_val, y_val_bin))
-print('\nTraining took {}'.format(datetime.now()-start))
-
-
-
-model.save('kf_model.model')
-
-with open('class_names.pkl', 'wb') as f:
- pickle.dump(mlb.classes_, f)
+ print('Total number of images: ', len(x))
+ print('Summary of labels:')
+ label_freq = sorted(zip(Counter(allLabels).keys(), Counter(allLabels).values()), key=lambda x: -x[1])
+ print(*label_freq,sep='\n')
-print(mlb.classes_)
-
-## set training flag back to false
-pickle.dump(False, open('retrain_flag.pkl', 'wb'))
-
-## zulip message saying the training is done
-client = zulip.Client(config_file='zuliprc')
-
-request = {
- "type": "stream",
- "to": "scanbot",
- "topic": "TF model",
- "content": "train_model finished"
- }
-result = client.send_message(request)
-print(result)
+ shuffle = np.random.permutation(range(x.shape[0])) # Shuffle the data
+
+ x = torch.tensor(x[shuffle]) # Convert to a pytorch compatible tensor
+ y = torch.tensor(y[shuffle], dtype=torch.long)
+
+ cutoff = int(0.9 * x.shape[0]) # Split into training and testing/validation subsets
+ x_train, x_test = x[:cutoff], x[cutoff:] # Image data
+ y_train, y_test = y[:cutoff], y[cutoff:] # Labels
+
+ net = ConvNet() # Instantiate the CNN
+ modelPath = net.train(x_train, y_train, x_test, y_test, name=runName, epochs=2, batch_size=64) # Train the CNN
+
+ return modelPath
\ No newline at end of file