-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFeatureModel.java
More file actions
252 lines (240 loc) · 7.29 KB
/
FeatureModel.java
File metadata and controls
252 lines (240 loc) · 7.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.PriorityQueue;
import java.util.Set;
/**
 * Comparator that orders {@code myString} values by delegating to their own
 * {@code compareTo}; it is the ordering of the bounded min-heap used to pick
 * the most frequent words.
 */
class min implements Comparator {
    public min() {
    }

    /**
     * Compares two {@code myString} instances.
     *
     * @param left first value; must be a {@code myString}
     * @param right second value; must be a {@code myString}
     * @return the result of {@code left.compareTo(right)}
     */
    @Override
    public int compare(Object left, Object right) {
        return ((myString) left).compareTo((myString) right);
    }
}
/**
 * Multi-label text classifier built from per-group word counts.
 *
 * <p>Each training group (one {@code ArrayList<sentence>} of the constructor's {@code t}) is a
 * label, identified by its position {@code 0..t.size()-1}. Features are the {@code num} most
 * frequent words returned by {@code sentence.wordAroundIndex(k)} over all training sentences.
 * A label {@code s} is scored for a sentence with the naive-Bayes-style product
 * {@code p(s) * prod_f (c_f / count(s))}, evaluated in log space, where {@code c_f} is
 * {@code count(f, s)} when feature {@code f} occurs in the sentence and a tiny constant
 * otherwise. A label is predicted when its score exceeds {@link #threshold}, which
 * {@link #runvalidation} calibrates.
 */
public class FeatureModel {
    /** Stand-in for {@code count(f, s)} when feature {@code f} does not occur in a sentence. */
    private static final double ABSENT_FEATURE_COUNT = 1.0 / 1000000000.0;

    public ArrayList<Integer> counts; // count(s), copied from the constructor's tempcount
    public ArrayList<Double> probs; // prob(s) = count(s) / sum of all count(s)
    public Hashtable<String, Integer> mostfrequent; // feature word -> occurrences in training data
    public int k; // window argument passed to sentence.wordAroundIndex (k surrounding words)
    public int num; // number of most frequent words kept as features
    public double threshold; // decision threshold; 0.0 until runvalidation calibrates it
    /* featurecount:
     * Integer: label index s
     * String: feature word f; feature: holds count(f, s)
     */
    public Hashtable<Integer, Hashtable<String, feature>> featurecount;
    // Confusion counts over (sentence, label) decisions; accumulated across runvalidation calls.
    public int tp = 0;
    public int fp = 0;
    public int fn = 0;
    public double score; // accuracy computed by the most recent runvalidation call
    public int totalwords; // not read or written within this class

    /**
     * Builds the model from grouped training sentences.
     *
     * @param kk window argument passed to {@code sentence.wordAroundIndex}
     * @param nn number of most frequent words to keep as features; values {@code <= 0} keep none
     * @param t training sentences, one list per label; list position is the label index
     * @param tempcount count(s) for each label, used to compute the priors {@link #probs}
     */
    public FeatureModel(int kk, int nn, ArrayList<ArrayList<sentence>> t, ArrayList<Integer> tempcount) {
        k = kk;
        num = nn;
        counts = new ArrayList<Integer>(tempcount);
        probs = computePriors(counts);
        mostfrequent = topWords(countWords(t), num);
        // count(f, s) needs mostfrequent, so it must be computed last.
        featurecount = countFeatures(t);
    }

    /** Converts per-label counts into relative frequencies (NaN entries if all counts are 0). */
    private static ArrayList<Double> computePriors(ArrayList<Integer> labelCounts) {
        int sum = 0;
        for (Integer c : labelCounts) {
            sum += c;
        }
        ArrayList<Double> priors = new ArrayList<Double>(labelCounts.size());
        for (Integer c : labelCounts) {
            priors.add(1.0 * c / sum);
        }
        return priors;
    }

    /** Counts how often each window word occurs across all training sentences. */
    private Hashtable<String, Integer> countWords(ArrayList<ArrayList<sentence>> groups) {
        Hashtable<String, Integer> table = new Hashtable<String, Integer>();
        for (ArrayList<sentence> group : groups) {
            for (sentence ss : group) {
                for (String w : ss.wordAroundIndex(k)) {
                    Integer seen = table.get(w);
                    table.put(w, seen == null ? 1 : seen + 1);
                }
            }
        }
        return table;
    }

    /** Keeps the {@code limit} highest-ranked words (by {@code myString} ordering) via a min-heap. */
    private static Hashtable<String, Integer> topWords(Hashtable<String, Integer> table, int limit) {
        Hashtable<String, Integer> top = new Hashtable<String, Integer>();
        if (limit <= 0) {
            // PriorityQueue rejects a capacity below 1, and nothing could be kept anyway.
            return top;
        }
        @SuppressWarnings("unchecked") // min is a raw Comparator that only ever sees myString
        PriorityQueue<myString> minHeap = new PriorityQueue<myString>(limit, new min());
        for (String w : table.keySet()) {
            myString candidate = new myString(w, table.get(w));
            if (minHeap.size() < limit) {
                minHeap.add(candidate);
            } else if (candidate.compareTo(minHeap.peek()) > 0) {
                // Evict the current smallest so the heap keeps the top `limit` words.
                minHeap.poll();
                minHeap.add(candidate);
            }
        }
        while (!minHeap.isEmpty()) {
            myString m = minHeap.poll();
            top.put(m.value(), m.count);
        }
        return top;
    }

    /** Builds count(f, s): for each group, how often each feature word occurs in its windows. */
    private Hashtable<Integer, Hashtable<String, feature>> countFeatures(ArrayList<ArrayList<sentence>> groups) {
        Hashtable<Integer, Hashtable<String, feature>> result =
                new Hashtable<Integer, Hashtable<String, feature>>();
        int groupindex = 0;
        for (ArrayList<sentence> group : groups) {
            Hashtable<String, feature> perWord = new Hashtable<String, feature>();
            for (String w : mostfrequent.keySet()) {
                perWord.put(w, new feature(groupindex, w));
            }
            for (sentence ss : group) {
                for (String w : ss.wordAroundIndex(k)) {
                    feature f = perWord.get(w);
                    if (f != null) {
                        f.addCount(); // mutates the stored object; no need to put it back
                    }
                }
            }
            result.put(groupindex, perWord);
            groupindex++;
        }
        return result;
    }

    /** Returns the feature words (members of mostfrequent) that occur in the sentence's window. */
    private Set<String> featuresPresent(sentence c) {
        Set<String> present = new HashSet<String>();
        for (String w : c.wordAroundIndex(k)) {
            if (mostfrequent.containsKey(w)) {
                present.add(w);
            }
        }
        return present;
    }

    /**
     * Scores one label: exp(log p(s) + sum over all features of log(c_f / count(s))), where c_f is
     * count(f, s) for present features and {@link #ABSENT_FEATURE_COUNT} otherwise.
     */
    private double labelProbability(Set<String> present, int label) {
        double logp = Math.log(probs.get(label));
        int labelCount = counts.get(label);
        Hashtable<String, feature> perWord = featurecount.get(label);
        for (String w : mostfrequent.keySet()) {
            double c = present.contains(w) ? perWord.get(w).count : ABSENT_FEATURE_COUNT;
            logp += Math.log(c / labelCount);
        }
        return Math.exp(logp);
    }

    /**
     * Calibrates {@link #threshold} on {@code v} and returns the accuracy over all
     * (sentence, label) decisions.
     *
     * <p>The threshold becomes the mean score of every (sentence, true label) pair in {@code v};
     * it replaces any previous threshold, and falls back to 0.0 (the uncalibrated default) when
     * {@code v} contains no true labels known to the model. {@link #tp}, {@link #fp} and
     * {@link #fn} are incremented, not reset.
     *
     * @param v validation sentences; {@code c.value} holds each sentence's true label indices
     * @return fraction of correct decisions, also stored in {@link #score}; 0.0 if there are none
     */
    public double runvalidation(ArrayList<sentence> v) {
        // Step 1: threshold = mean score over (sentence, true label) pairs.
        double thresholdSum = 0.0;
        int thresholdcount = 0;
        for (sentence c : v) {
            Set<String> present = featuresPresent(c);
            for (int label : featurecount.keySet()) {
                if (c.value.contains(label)) {
                    thresholdSum += labelProbability(present, label);
                    thresholdcount++; // exactly once per pair (previously counted twice)
                }
            }
        }
        this.threshold = thresholdcount == 0 ? 0.0 : thresholdSum / thresholdcount;

        // Step 2: classify every (sentence, label) pair against the new threshold.
        int correct = 0;
        int total = 0;
        for (sentence c : v) {
            Set<String> present = featuresPresent(c);
            for (int label : featurecount.keySet()) {
                boolean actual = c.value.contains(label);
                boolean predicted = labelProbability(present, label) > this.threshold;
                if (predicted && actual) {
                    tp++;
                    correct++;
                } else if (predicted) {
                    fp++;
                } else if (actual) {
                    fn++;
                } else {
                    correct++;
                }
                total++;
            }
        }

        // Step 3: accuracy.
        score = total == 0 ? 0.0 : 1.0 * correct / total;
        return score;
    }

    /**
     * Predicts which labels apply to a sentence.
     *
     * @param c the sentence to classify
     * @return a list of size {@code probs.size() + 1}: element 0 is 1 if any label was predicted
     *     (else 0), and element {@code i + 1} is 1 if label {@code i} was predicted (else 0)
     */
    public ArrayList<Integer> predict(sentence c) {
        ArrayList<Integer> result = new ArrayList<Integer>(probs.size() + 1);
        for (int i = 0; i < probs.size(); i++) {
            result.add(0);
        }
        boolean any = false;
        Set<String> present = featuresPresent(c);
        for (int label : featurecount.keySet()) {
            // Same log-space score used to calibrate the threshold in runvalidation.
            if (labelProbability(present, label) > this.threshold) {
                any = true;
                result.set(label, 1);
            }
        }
        result.add(0, any ? 1 : 0);
        return result;
    }
}