diff --git a/GCNFrame/Biodata.py b/GCNFrame/Biodata.py index da14ebc..dd3f7dc 100644 --- a/GCNFrame/Biodata.py +++ b/GCNFrame/Biodata.py @@ -1,5 +1,7 @@ import numpy as np +import platform from Bio import SeqIO +from concurrent.futures import ThreadPoolExecutor from multiprocessing import Pool from functools import partial @@ -130,11 +132,17 @@ def __init__(self, fasta_file, label_file=None, feature_file=None, K=3, d=3, seq def encode(self, thread=10, save_dataset=True, save_path="./"): print("Encoding sequences...") seq_list = list(self.dna_seq.values()) - pool = Pool(thread) partial_encode_seq = partial(encode_seq.matrix_encoding, K=self.K, d=self.d, seqtype=self.seqtype) - feature = np.array(pool.map(partial_encode_seq, seq_list)) - pool.close() - pool.join() + + if platform.system() == 'Darwin': + with ThreadPoolExecutor(max_workers=thread) as executor: + feature = np.array(list(executor.map(partial_encode_seq, seq_list))) + else: + pool = Pool(thread) + feature = np.array(pool.map(partial_encode_seq, seq_list)) + pool.close() + pool.join() + self.pnode_feature = feature.reshape(-1, self.d, 4**(self.K*2)) self.pnode_feature = np.moveaxis(self.pnode_feature, 1, 2) zero_layer = feature.reshape(-1, self.d, 4**self.K, 4**self.K)[:, 0, :, :] @@ -143,7 +151,6 @@ def encode(self, thread=10, save_dataset=True, save_path="./"): if save_dataset: dataset = GraphDatasetInMem(self.pnode_feature, self.fnode_feature, self.other_feature, self.edge, self.label, root=save_path) - else: graph = GraphDataset(self.pnode_feature, self.fnode_feature, self.other_feature, self.edge, self.label) dataset = graph.process()