00001
00023 #ifndef _TRIE_H
00024 #define _TRIE_H
00025
#include <assert.h>
#include <cstddef>
#include <cstdio>
#include <list>
#include <string>

#include "defines.h"

using namespace std;
00030
/// @brief Prefix tree over sequences of @c _type symbols.
///
/// Every node stores one symbol, a back pointer to its parent and a frequency
/// counter; the root is a sentinel holding _type(0). The counters are read by
/// calculateProbability(), which walks from a context node back towards the
/// root through the parent pointers.
template<typename _type>
class Trie {
private:
    /// @brief A single trie node.
    class Node {
    public:
        /// Symbol on the edge leading to this node (0 for the root).
        const _type value;
        /// Parent node, or NULL for the root. Children live in std::list
        /// elements, whose addresses stay stable as siblings are appended.
        const Node* parent;
        /// Occurrence counter; every node starts at 1.
        unsigned long frequency;
        /// Child nodes.
        list<Node> childs;

        /// Builds a root node: value 0, no parent, frequency 1.
        Node() : value(0), parent(NULL) {
            frequency = 1;
        }

        /// Builds a node holding @p val below @p parent, frequency 1.
        Node(const _type val, const Node* parent) : value(val), parent(parent) {
            frequency = 1;
        }

        /// @brief Dumps this subtree, one line per node, formatted as
        /// "<indent><value as long> (<value as char>): <frequency>".
        /// @param level depth of this node; one space of indentation per level.
        /// @return the dump of this node followed by all of its descendants.
        string toString(unsigned int level) const {
            string ret;
            for (unsigned int i = 0; i < level; i++)
                ret += " ";

            // Indentation is kept out of the fixed buffer so deep trees cannot
            // overflow it; snprintf bounds the write anyway, and %lu matches
            // the unsigned long counter.
            char buf[96];
            const int len = snprintf(buf, sizeof(buf), "%ld (%c): %lu\n",
                                     (long) value, (char) value, frequency);
            if (len > 0) {
                // Append by length rather than as a C string: (char) value can
                // be '\0' (e.g. the root), which would otherwise cut the line
                // short. That byte is kept in the returned string.
                const size_t n = ((size_t) len < sizeof(buf)) ? (size_t) len : sizeof(buf) - 1;
                ret.append(buf, n);
            }

            for (typename list<Node>::const_iterator node = childs.begin(); node != childs.end(); ++node)
                ret += node->toString(level + 1);
            return ret;
        }
    };

    /// Sentinel root; its children are the first symbols of inserted sequences.
    Node root;

public:
    /// Creates an empty trie (only the root node).
    Trie();

    /// @brief Deep copy of @p other.
    /// A member-wise copy would leave every copied node's parent pointer
    /// aiming into @p other's tree, so the parents are re-targeted here.
    Trie(const Trie& other) : root(other.root) {
        relinkParents(root);
    }

    ~Trie();

    /// @brief Returns true if @p sequence is a path from the root (prefixes of
    /// inserted sequences included). @p sequence must not be NULL.
    bool findSequence(const list<_type>* sequence) const;

    /// @brief Inserts @p sequence, creating missing nodes (frequency 1).
    /// @param sequence        symbols to insert; must not be NULL.
    /// @param updateFrequency if true, increments the final node's frequency.
    void insertSequence(const list<_type>* sequence, bool updateFrequency=false);

    /// @brief Probability estimate for @p nextValue following @p context,
    /// combining child frequencies from the context node back to the root.
    double calculateProbability(const list<_type>* context, const _type nextValue) const;

    /// Human-readable dump of the whole trie.
    string toString() const;

    /// Not implemented: currently throws an int.
    string serialize() const;
    /// Not implemented: currently throws an int.
    void unserialize(string data);

private:
    /// Points every descendant's parent pointer at the node that actually
    /// contains it (needed after copying a subtree).
    static void relinkParents(Node& node)
    {
        for (typename list<Node>::iterator child = node.childs.begin(); child != node.childs.end(); ++child) {
            child->parent = &node;
            relinkParents(*child);
        }
    }

    /// @brief Follows @p sequence from the root.
    /// @param sequence symbols to follow; must not be NULL.
    /// @return the node reached after the last symbol (&root for an empty
    ///         sequence), or NULL as soon as a symbol has no matching child.
    const Node* navigateTree(const list<_type>* sequence) const
    {
        const Node* cur = &root;
        for (typename list<_type>::const_iterator symbol = sequence->begin(); symbol != sequence->end(); ++symbol) {
            const Node* next = NULL;
            for (typename list<Node>::const_iterator node = cur->childs.begin(); node != cur->childs.end(); ++node) {
                if (node->value == *symbol) {
                    next = &(*node);
                    break;
                }
            }
            if (next == NULL)
                return NULL;
            cur = next;
        }
        return cur;
    }
};
00127
00128 template<typename _type>
00129 Trie<_type>::Trie() {}
00130
00131 template<typename _type>
00132 Trie<_type>::~Trie() {}
00133
00134 template<typename _type>
00135 bool Trie<_type>::findSequence(const list<_type>* sequence) const
00136 {
00137 return (navigateTree(sequence) != NULL);
00138 }
00139
00140 template<typename _type>
00141 void Trie<_type>::insertSequence(const list<_type>* sequence, bool updateFrequency)
00142 {
00143 Node* cur = &root;
00144 for (typename list<_type>::const_iterator symbol = sequence->begin(); symbol != sequence->end(); symbol++) {
00145 bool foundsymb = false;
00146 for (typename list<Node>::iterator node = cur->childs.begin(); !foundsymb && node != cur->childs.end(); node++) {
00147 if ((*node).value == *symbol) {
00148 foundsymb = true;
00149 cur = &(*node);
00150
00151
00152 }
00153 }
00154 if (!foundsymb) {
00155
00156 Node newNode(*symbol, cur);
00157 cur->childs.push_back(newNode);
00158 cur = &newNode;
00159 }
00160 }
00161
00162 if (updateFrequency)
00163 (*cur).frequency++;
00164 }
00165
00166 template<typename _type>
00167 double Trie<_type>::calculateProbability(const list<_type>* context, const _type nextValue) const
00168 {
00169
00170 double prob = 0.0, escapeProb = 1.0;
00171 unsigned long sumfreq, foundfreq;
00172
00173 const Node* cur = navigateTree(context);
00174 assert(cur != NULL);
00175
00176
00177
00178 if (cur->childs.size() == 0)
00179 cur = cur->parent;
00180
00181 while (cur != &root && cur != NULL) {
00182
00183 sumfreq = foundfreq = 0;
00184
00185 for (typename list<Node>::const_iterator node = cur->childs.begin(); node != cur->childs.end(); node++) {
00186
00187 sumfreq += (*node).frequency;
00188
00189 if ((*node).value == nextValue) {
00190 assert(foundfreq == 0);
00191 foundfreq = (*node).frequency;
00192 }
00193 }
00194 prob += escapeProb * ((double) foundfreq / cur->frequency);
00195 escapeProb *= ((double) sumfreq / cur->frequency);
00196
00197 cur = cur->parent;
00198 }
00199 sumfreq = foundfreq = 0;
00200 for (typename list<Node>::const_iterator node = cur->childs.begin(); node != cur->childs.end(); node++) {
00201 sumfreq += (*node).frequency;
00202
00203 if ((*node).value == nextValue) {
00204 assert(foundfreq == 0);
00205 foundfreq = (*node).frequency;
00206 }
00207 }
00208 prob += escapeProb *((double) foundfreq / sumfreq);
00209 return prob;
00210 }
00211
00212 template<typename _type>
00213 string Trie<_type>::toString() const
00214 {
00215 return root.toString(0);
00216 }
00217
00218 template<typename _type>
00219 string Trie<_type>::serialize() const
00220 {
00221 throw 1;
00222 return "";
00223 }
00224
00225 template<typename _type>
00226 void Trie<_type>::unserialize(string data)
00227 {
00228 throw 1;
00229 }
00230
00231 #endif