presage 0.9.1
ARPAPredictor.cpp
Go to the documentation of this file.
1
2/******************************************************
3 * Presage, an extensible predictive text entry system
4 * ---------------------------------------------------
5 *
6 * Copyright (C) 2008 Matteo Vescovi <matteo.vescovi@yahoo.co.uk>
7
8 This program is free software; you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation; either version 2 of the License, or
11 (at your option) any later version.
12
13 This program is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
17
18 You should have received a copy of the GNU General Public License along
19 with this program; if not, write to the Free Software Foundation, Inc.,
20 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 *
22 **********(*)*/
23
24
25#include "ARPAPredictor.h"
26
27
28#include <sstream>
29#include <algorithm>
30#include <cmath>
31
32
33#define OOV "<UNK>"
34
35
36
39 ct,
40 name,
41 "ARPAPredictor, a predictor relying on an ARPA language model",
42 "ARPAPredictor, long description."
43 ),
44 dispatcher (this)
45{
46 LOGGER = PREDICTORS + name + ".LOGGER";
47 ARPAFILENAME = PREDICTORS + name + ".ARPAFILENAME";
48 VOCABFILENAME = PREDICTORS + name + ".VOCABFILENAME";
49 TIMEOUT = PREDICTORS + name + ".TIMEOUT";
50
51 // build notification dispatch map
56
59}
60
61void ARPAPredictor::set_vocab_filename (const std::string& value)
62{
63 logger << INFO << "VOCABFILENAME: " << value << endl;
64 vocabFilename = value;
65}
66
67void ARPAPredictor::set_arpa_filename (const std::string& value)
68{
69 logger << INFO << "ARPAFILENAME: " << value << endl;
70 arpaFilename = value;
71}
72
73void ARPAPredictor::set_timeout (const std::string& value)
74{
75 logger << INFO << "TIMEOUT: " << value << endl;
76 timeout = atoi(value.c_str());
77}
78
// Body of the vocabulary loader (the signature line is not visible in this
// listing; presumably ARPAPredictor::loadVocabulary() -- TODO confirm).
//
// Reads the file named by vocabFilename one line at a time and assigns each
// kept line a consecutive integer code starting at 0. Both directions of the
// mapping are stored: vocabCode (word -> code) and vocabDecode (code -> word).
80{
81 std::ifstream vocabFile;
82 vocabFile.open(vocabFilename.c_str());
   // A failed open is logged first, then trapped by the assert below.
83 if(!vocabFile)
84 logger << ERROR << "Error opening vocabulary file: " << vocabFilename << endl;
85
   // NOTE(review): assert() is compiled out when NDEBUG is defined; in that
   // case a missing file just results in an empty vocabulary.
86 assert(vocabFile);
87 std::string row;
88 int code = 0;
89 while(std::getline(vocabFile,row))
90 {
   // Lines whose first character is '#' are skipped and consume no code.
   // NOTE(review): blank lines are not filtered here, so an empty row is
   // stored as the empty word.
91 if(row[0]=='#')
92 continue;
93
   // The whole line, unmodified, is the vocabulary word.
94 vocabCode[row]=code;
95 vocabDecode[code]=row;
96
97 logger << DEBUG << "["<<row<<"] -> "<< code<<endl;
98
99 code++;
100 }
101
   // After the loop, code equals the number of lines that were stored.
102 logger << DEBUG << "Loaded "<<code<<" words from vocabulary" <<endl;
103
104}
105
// Body of the ARPA model loader (the signature line is not visible in this
// listing; presumably ARPAPredictor::createARPATable() -- TODO confirm).
//
// Reads the file named by arpaFilename line by line. After the "\data\"
// marker it picks up the "ngram N" totals, then hands every row of the
// "\1-grams:", "\2-grams:" and "\3-grams:" sections to addUnigram(),
// addBigram() and addTrigram() respectively. Parsing stops at "\end\".
107{
108 std::ifstream arpaFile;
109 arpaFile.open(arpaFilename.c_str());
110
    // A failed open is logged first, then trapped by the assert below.
111 if(!arpaFile)
112 logger << ERROR << "Error opening ARPA model file: " << arpaFilename << endl;
113
    // NOTE(review): assert() is compiled out when NDEBUG is defined; in that
    // case a missing file just results in empty tables.
114 assert(arpaFile);
115 std::string row;
116
    // 0 = no n-gram section entered yet; 1, 2, 3 = order of the current section.
117 int currOrder = 0;
118
    // Running counts of the rows handed to the add*gram() helpers.
119 unigramCount = 0;
120 bigramCount = 0;
121 trigramCount = 0;
122
    // lineNum is incremented per line but not read anywhere in this body.
123 int lineNum =0;
    // Becomes true once the "\data\" marker has been seen.
124 bool startData = false;
125
126 while(std::getline(arpaFile,row))
127 {
128 lineNum++;
129 if(row.empty())
130 continue;
131
    // End-of-model marker: stop reading.
132 if(row == "\\end\\")
133 break;
134
135 if(row == "\\data\\")
136 {
137 startData = true;
138 continue;
139 }
140
141
    // Header area: between "\data\" and the first section marker.
142 if( startData == true && currOrder == 0)
143 {
    // substr(8) drops the first 8 characters (presumably "ngram N=") and
    // atoi() converts what is left.
    // NOTE(review): substr(8) throws std::out_of_range when the row has
    // fewer than 8 characters, e.g. a bare "ngram 1".
144 if( row.find("ngram 1")==0 )
145 {
146 unigramTot = atoi(row.substr(8).c_str());
147 logger << DEBUG << "tot unigram = "<<unigramTot<<endl;
148 continue;
149 }
150
151 if( row.find("ngram 2")==0)
152 {
153 bigramTot = atoi(row.substr(8).c_str());
154 logger << DEBUG << "tot bigram = "<<bigramTot<<endl;
155 continue;
156 }
157
158 if( row.find("ngram 3")==0)
159 {
160 trigramTot = atoi(row.substr(8).c_str());
161 logger << DEBUG << "tot trigram = "<<trigramTot<<endl;
162 continue;
163 }
164 }
165
    // Section markers: switch the current order and allocate a progress bar
    // that writes to std::cerr.
    // NOTE(review): each bar is allocated with a raw new; if a marker occurs
    // twice in the file, the earlier bar is overwritten without being deleted.
166 if( row == "\\1-grams:" && startData)
167 {
168 currOrder = 1;
169 std::cerr << std::endl << "ARPA loading unigrams:" << std::endl;
170 unigramProg = new ProgressBar<char>(std::cerr);
171 continue;
172 }
173
174 if( row == "\\2-grams:" && startData)
175 {
176 currOrder = 2;
177 std::cerr << std::endl << std::endl << "ARPA loading bigrams:" << std::endl;
178 bigramProg = new ProgressBar<char>(std::cerr);
179 continue;
180 }
181
182 if( row == "\\3-grams:" && startData)
183 {
184 currOrder = 3;
185 std::cerr << std::endl << std::endl << "ARPA loading trigrams:" << std::endl;
186 trigramProg = new ProgressBar<char>(std::cerr);
187 continue;
188 }
189
    // Any other row seen before the first section marker is ignored.
190 if(currOrder == 0)
191 continue;
192
    // Data row: hand it to the parser for the current section.
193 switch(currOrder)
194 {
195 case 1: addUnigram(row);
196 break;
197
198 case 2: addBigram(row);
199 break;
200
201 case 3: addTrigram(row);
202 break;
203 }
204
205 }
206
207 std::cerr << std::endl << std::endl;
208
209 logger << DEBUG << "loaded unigrams: "<< unigramCount << endl;
210 logger << DEBUG << "loaded bigrams: " << bigramCount << endl;
211 logger << DEBUG << "loaded trigrams: "<< trigramCount << endl;
212}
213
// Parses one row of the 1-gram section and stores it in unigramMap.
//
// The row is read as three whitespace-separated fields, in this order:
// log probability, word, back-off weight. A field that cannot be read keeps
// its initial value (0 for the numbers, empty for the word).
//
// NOTE(review): the embedded line numbering of this listing skips a line just
// before the closing brace, so one statement of the original function is not
// shown here. addBigram() refreshes its progress bar at that position.
214void ARPAPredictor::addUnigram(std::string row)
215{
216 std::stringstream str(row);
217 float logProb = 0;
218 float logAlfa = 0;
219 std::string wd1Str;
220
221 str >> logProb;
222 str >> wd1Str;
223 str >> logAlfa;
224
225
    // Rows for the out-of-vocabulary token (OOV) are counted but not stored.
226 if(wd1Str != OOV )
227 {
    // NOTE(review): operator[] inserts wd1Str with code 0 when the word is not
    // already in vocabCode, so an unknown word ends up sharing code 0.
228 int wd1 = vocabCode[wd1Str];
229
230 unigramMap[wd1]= ARPAData(logProb,logAlfa);
231
232 logger << DEBUG << "adding unigram ["<<wd1Str<< "] -> "<<logProb<<" "<<logAlfa<<endl;
233 }
234
235
    // Counts every row handed to this function, stored or not.
236 unigramCount++;
237
239}
240
241void ARPAPredictor::addBigram(std::string row)
242{
243 std::stringstream str(row);
244 float logProb = 0;
245 float logAlfa = 0;
246 std::string wd1Str;
247 std::string wd2Str;
248
249 str >> logProb;
250 str >> wd1Str;
251 str >> wd2Str;
252 str >> logAlfa;
253
254 if(wd1Str != OOV && wd2Str != OOV)
255 {
256 int wd1 = vocabCode[wd1Str];
257 int wd2 = vocabCode[wd2Str];
258
259 bigramMap[BigramKey(wd1,wd2)]=ARPAData(logProb,logAlfa);
260
261 logger << DEBUG << "adding bigram ["<<wd1Str<< "] ["<<wd2Str<< "] -> "<<logProb<<" "<<logAlfa<<endl;
262 }
263
264 bigramCount++;
265 bigramProg->update((float)bigramCount/(float)bigramTot);
266}
267
// Parses one row of the 3-gram section and stores it in trigramMap.
//
// The row is read as four whitespace-separated fields, in this order:
// log probability, first word, second word, third word. No back-off weight is
// read for trigrams.
//
// NOTE(review): the embedded line numbering of this listing skips a line just
// before the closing brace, so one statement of the original function is not
// shown here. addBigram() refreshes its progress bar at that position.
268void ARPAPredictor::addTrigram(std::string row)
269{
270 std::stringstream str(row);
271 float logProb = 0;
272
273 std::string wd1Str;
274 std::string wd2Str;
275 std::string wd3Str;
276
277 str >> logProb;
278 str >> wd1Str;
279 str >> wd2Str;
280 str >> wd3Str;
281
    // Rows mentioning the out-of-vocabulary token (OOV) are counted but not stored.
282 if(wd1Str != OOV && wd2Str != OOV && wd3Str != OOV)
283 {
    // NOTE(review): operator[] inserts a word with code 0 when it is not
    // already in vocabCode, so unknown words end up sharing code 0.
284 int wd1 = vocabCode[wd1Str];
285 int wd2 = vocabCode[wd2Str];
286 int wd3 = vocabCode[wd3Str];
287
    // Only the log probability is kept for a trigram.
288 trigramMap[TrigramKey(wd1,wd2,wd3)]=logProb;
289 logger << DEBUG << "adding trigram ["<<wd1Str<< "] ["<<wd2Str<< "] ["<<wd3Str<< "] -> "<<logProb <<endl;
290
291 }
292
    // Counts every row handed to this function, stored or not.
293 trigramCount++;
295}
296
297
// Body that releases the three progress bars (the signature line is not
// visible in this listing; presumably the ARPAPredictor destructor -- TODO
// confirm). The bars are the ones allocated with new while the ARPA file is
// loaded.
// NOTE(review): each pointer must be null or point to a live object when this
// runs; their initialisation is not visible in this listing.
299{
300 delete unigramProg;
301 delete bigramProg;
302 delete trigramProg;
303}
304
305bool ARPAPredictor::matchesPrefixAndFilter(std::string word, std::string prefix, const char** filter ) const
306{
307 if(filter == 0)
308 return word.find(prefix)==0;
309
310 for(int j = 0; filter[j] != 0; j++)
311 {
312 std::string pattern = prefix+std::string(filter[j]);
313 if(word.find(pattern)==0)
314 return true;
315 }
316
317 return false;
318}
319
320Prediction ARPAPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
321{
322 logger << DEBUG << "predict()" << endl;
323 Prediction prediction;
324
325 int cardinality = 3;
326 std::vector<std::string> tokens(cardinality);
327
328 std::string prefix = Utility::strtolower(contextTracker->getToken(0));
329 std::string wd2Str = Utility::strtolower(contextTracker->getToken(1));
330 std::string wd1Str = Utility::strtolower(contextTracker->getToken(2));
331
332 std::multimap< float, std::string, cmp > result;
333
334 logger << DEBUG << "["<<wd1Str<<"]"<<" ["<<wd2Str<<"] "<<"["<<prefix<<"]"<<endl;
335
336 //search for the past tokens in the vocabulary
337 std::map<std::string,int>::const_iterator wd1It,wd2It;
338 wd1It = vocabCode.find(wd1Str);
339 wd2It = vocabCode.find(wd2Str);
340
346 //we have two valid past tokens available
347 if(wd1It!=vocabCode.end() && wd2It!=vocabCode.end())
348 {
349 //iterate over all vocab words
350 for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); ++it) //cppcheck: Prefer prefix ++/-- operators for non-primitive types.
351 {
352 //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set
353 if(matchesPrefixAndFilter(it->second,prefix,filter))
354 {
355 std::pair<const float,std::string> p (computeTrigramBackoff(wd1It->second,wd2It->second,it->first),
356 it->second);
357 result.insert(p);
358 }
359 }
360 }
361
362 //we have one valid past token available
363 else if(wd2It!=vocabCode.end())
364 {
365 //iterate over all vocab words
366 for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); ++it)
367 {
368 //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set
369 if(matchesPrefixAndFilter(it->second,prefix,filter))
370 {
371 std::pair<const float,std::string> p(computeBigramBackoff(wd2It->second,it->first),
372 it->second);
373 result.insert(p);
374 }
375 }
376 }
377
378 //we have no valid past token available
379 else
380 {
381 //iterate over all vocab words
382 for(std::map<int,std::string>::const_iterator it = vocabDecode.begin(); it!=vocabDecode.end(); ++it)
383 {
384 //if wd3 matches prefix and filter -> compute its backoff probability and add to the result set
385 if(matchesPrefixAndFilter(it->second,prefix,filter))
386 {
387 std::pair<const float,std::string> p (unigramMap.find(it->first)->second.logProb,
388 it->second);
389 result.insert(p);
390 }
391 }
392 }
393
394
395 size_t numSuggestions = 0;
396 for(std::multimap< float, std::string, cmp >::const_iterator it = result.begin();
397 it != result.end() && numSuggestions < max_partial_prediction_size;
398 ++it)
399 {
400 prediction.addSuggestion(Suggestion(it->second,exp(it->first)));
401 numSuggestions++;
402 }
403
404 return prediction;
405}
409float ARPAPredictor::computeTrigramBackoff(int wd1,int wd2,int wd3) const
410{
411 logger << DEBUG << "computing P( ["<<vocabDecode.find(wd3)->second<< "] | ["<<vocabDecode.find(wd1)->second<<"] ["<<vocabDecode.find(wd2)->second<<"] )"<<endl;
412
413 //trigram exist
414 std::map<TrigramKey,float>::const_iterator trigramIt =trigramMap.find(TrigramKey(wd1,wd2,wd3));
415 if(trigramIt!=trigramMap.end())
416 {
417 logger << DEBUG << "trigram ["<<vocabDecode.find(wd1)->second<< "] ["<<vocabDecode.find(wd2)->second<< "] ["<<vocabDecode.find(wd3)->second<< "] exists" <<endl;
418 logger << DEBUG << "returning "<<trigramIt->second <<endl;
419 return trigramIt->second;
420 }
421
422 //bigram exist
423 std::map<BigramKey,ARPAData>::const_iterator bigramIt =bigramMap.find(BigramKey(wd1,wd2));
424 if(bigramIt!=bigramMap.end())
425 {
426 logger << DEBUG << "bigram ["<<vocabDecode.find(wd1)->second<< "] ["<<vocabDecode.find(wd2)->second<< "] exists" <<endl;
427 float prob = bigramIt->second.logAlfa + computeBigramBackoff(wd2,wd3);
428 logger << DEBUG << "returning "<<prob<<endl;
429 return prob;
430 }
431
432 //else
433 logger << DEBUG << "no bigram w1,w2 exist" <<endl;
434 float prob = computeBigramBackoff(wd2,wd3);
435 logger << DEBUG << "returning "<<prob<<endl;
436 return prob;
437
438}
439
443float ARPAPredictor::computeBigramBackoff(int wd1, int wd2) const
444{
445 //bigram exist
446 std::map<BigramKey,ARPAData>::const_iterator bigramIt =bigramMap.find(BigramKey(wd1,wd2));
447 if(bigramIt!=bigramMap.end())
448 return bigramIt->second.logProb;
449
450 //else
451 return unigramMap.find(wd1)->second.logAlfa +unigramMap.find(wd2)->second.logProb;
452
453}
454
455void ARPAPredictor::learn(const std::vector<std::string>& change)
456{
457 logger << DEBUG << "learn() method called" << endl;
458 logger << DEBUG << "learn() method exited" << endl;
459}
460
// Body of the observer callback (the signature line is not visible in this
// listing; presumably ARPAPredictor::update(const Observable* var) -- TODO
// confirm). Logs the name and value of the observed variable at DEBUG level,
// then passes the variable to the dispatcher.
462{
463 logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
464 dispatcher.dispatch (var);
465}
#define OOV
virtual Prediction predict(const size_t size, const char **filter) const
Generate prediction.
void addBigram(std::string)
ProgressBar< char > * bigramProg
void addUnigram(std::string)
std::string vocabFilename
std::map< std::string, int > vocabCode
void createARPATable()
std::map< TrigramKey, float > trigramMap
void set_arpa_filename(const std::string &value)
virtual void update(const Observable *variable)
float computeBigramBackoff(int, int) const
std::string VOCABFILENAME
Dispatcher< ARPAPredictor > dispatcher
std::map< int, std::string > vocabDecode
ProgressBar< char > * unigramProg
bool matchesPrefixAndFilter(std::string, std::string, const char **) const
float computeTrigramBackoff(int, int, int) const
ProgressBar< char > * trigramProg
virtual void learn(const std::vector< std::string > &change)
std::string arpaFilename
std::string ARPAFILENAME
void set_vocab_filename(const std::string &value)
ARPAPredictor(Configuration *, ContextTracker *, const char *)
void addTrigram(std::string)
void set_timeout(const std::string &value)
std::map< int, ARPAData > unigramMap
std::map< BigramKey, ARPAData > bigramMap
std::string TIMEOUT
std::string LOGGER
void loadVocabulary()
Tracks user interaction and context.
std::string getToken(const int) const
void dispatch(const Observable *var)
Definition: dispatcher.h:73
void map(Observable *var, const mbr_func_ptr_t &ptr)
Definition: dispatcher.h:62
virtual std::string get_name() const =0
virtual std::string get_value() const =0
void addSuggestion(Suggestion)
Definition: prediction.cpp:90
ContextTracker * contextTracker
Definition: predictor.h:83
const std::string PREDICTORS
Definition: predictor.h:81
virtual void set_logger(const std::string &level)
Definition: predictor.cpp:88
Logger< char > logger
Definition: predictor.h:87
const std::string name
Definition: predictor.h:77
void update(const double percentage)
Definition: progress.h:54
static char * strtolower(char *)
Definition: utility.cpp:42
const Logger< _charT, _Traits > & endl(const Logger< _charT, _Traits > &lgr)
Definition: logger.h:278
std::string config
Definition: presageDemo.cpp:70