anygramwrapper.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. Wrapper around any-gram kernel
  5. Author: Rasoul Kaljahi
  6. See LICENSE file.
  7. """
  8. from sklearn import svm
  9. from sklearn.multiclass import OneVsRestClassifier
  10. import pickle
  11. import numpy as np
  12. class AnyGramWrapper:
  13. '''
  14. Wrapper class around any-gram kernel to be used by scikit learn SVC
  15. '''
  16. def __init__(self, pAnyGram):
  17. '''
  18. Constructor
  19. '''
  20. self.trainXs = None
  21. self.trainYs = None
  22. self.trainAuxs = None # auxiliary training input
  23. self.testXs = None
  24. self.testYs = None
  25. self.testAuxs = None # auxiliary test input
  26. self.ag = pAnyGram
  27. self._pcTrainKernel = None # precomputed kernel from the preloaded training data
  28. self._pcTestKernel = None # precomputed kernel from the preloaded test data
  29. self._model = None
  30. def loadTrainSet(self, pXs, pYs):
  31. '''
  32. Loads the training data
  33. The data set is provided via two lists, one containing a list of text instances (preferably tokenized) and one
  34. containing the labels matching the text instances.
  35. '''
  36. self.trainXs = self.ag.formatData(pXs)
  37. self.trainYs = [int(y) for y in pYs]
  38. self._pcTrainKernel = None
  39. def loadTrainAux(self, pAux):
  40. '''
  41. Loads auxiliary training input
  42. The input must be a 3D array/list of shape (#train_innstances, #max_text_length, #auxiliary_dimension
  43. '''
  44. if self.trainXs is None or len(self.trainXs) == 0:
  45. raise Exception("Training dataset should be loaded first!")
  46. self.trainAuxs = self.formatAux(pAux)
  47. def loadTestSet(self, pXs, pYs):
  48. '''
  49. Loads the training data
  50. The data set is provided via two lists, one containing a list of text instances (preferably tokenized) and one
  51. containing the labels matching the text instances.
  52. '''
  53. self.testXs = self.ag.formatData(pXs)
  54. self.testYs = [int(y) for y in pYs]
  55. self._pcTestKernel = None
  56. def loadTestAux(self, pAux):
  57. '''
  58. Loads auxiliary test input
  59. The input must be a 3D array/list of shape (#train_innstances, #max_text_length, #auxiliary_dimension
  60. '''
  61. if self.testXs is None or len(self.testXs) == 0:
  62. raise Exception("Test dataset should be loaded first!")
  63. self.testAuxs = self.formatAux(pAux)
  64. def formatAux(self, pAux):
  65. '''
  66. Formats the auxiliary data
  67. The input is a 3D list of shape (#train_innstances, #max_text_length, #auxiliary_dimension)
  68. '''
  69. if isinstance(pAux, list):
  70. # padding the 2nd dimensions of list (#max_text_length)
  71. vFixedDimList = [s[:self.ag.maxTxtLen] + [[0] * len(s[0])] * (self.ag.maxTxtLen - len(s)) for s in pAux]
  72. for s in vFixedDimList:
  73. if len(s) != 80:
  74. print len(s)
  75. for t in s:
  76. if len(t) != 1:
  77. print len(t)
  78. for a in t:
  79. if not isinstance(a, int):
  80. print type(a)
  81. return np.array(vFixedDimList, dtype=np.float64)
  82. else:
  83. return pAux
  84. def loadEmbeddings(self, pWEFilename, pIsLowerCase):
  85. '''
  86. Loads word embeddings from the given file
  87. Embeddings should be loaded before the data, because the vocabulary used for formatting data should
  88. be extracted from the word embeddings (when the word embeddings are used).
  89. pIsLowerCase specifies whether the vocabulary of the word embedding is lowercased.
  90. '''
  91. self.ag.loadEmbeddings(pWEFilename, pIsLowerCase)
  92. def precomputeTrainKernel(self):
  93. '''
  94. Precomputes the kernel from the training data
  95. '''
  96. self._pcTrainKernel = self._preComputeKernel(self.trainXs, self.trainXs, self.trainAuxs, self.trainAuxs)
  97. def _getPrecomputedTrainKernel(self):
  98. '''
  99. Returns the precomputed kernel from the training data
  100. If the kernel is not precomputed yet, it will do so first.
  101. '''
  102. if self._pcTrainKernel is None:
  103. self.precomputeTrainKernel()
  104. return self._pcTrainKernel
  105. def precomputeTestKernel(self):
  106. '''
  107. Precomputes the kernel from the test data
  108. '''
  109. self._pcTestKernel = self._preComputeKernel(self.testXs, self.trainXs, self.testAuxs, self.trainAuxs)
  110. def _getPrecomputedTestKernel(self):
  111. '''
  112. Returns the precomputed kernel from the test data
  113. If the kernel is not precomputed yet, it will do so first.
  114. '''
  115. if self._pcTestKernel is None:
  116. self.precomputeTestKernel()
  117. return self._pcTestKernel
  118. def _preComputeKernel(self, pX1, pX2, pAux1 = None, pAux2 = None):
  119. '''
  120. Computes and returns kernel with the given data
  121. The data should be formatted by AnyGram.formatData().
  122. '''
  123. if self.trainXs is None or len(self.trainXs) == 0:
  124. raise Exception("Training dataset is empty!")
  125. if pAux1 is not None and pAux2 is not None:
  126. return self.ag.computeKernel(pX1.astype(np.float64),
  127. pX2.astype(np.float64),
  128. pAux1.astype(np.float64),
  129. pAux2.astype(np.float64))
  130. else:
  131. return self.ag.computeKernel(pX1.astype(np.float64),
  132. pX2.astype(np.float64))
  133. def combinePrecomputedTrainKernel(self, paKernelMatrix, pCombMethod):
  134. '''
  135. Combines the anygram kernel computed on the training set here with any given kernel matrix using the specified method
  136. The given kernel matrix should match the computed any-gram kernel matrix in shape.
  137. The combination methods supported here are:
  138. + or add: add corresponding elements in the two kernel matrices
  139. * or multiply: multiply corresponding elements in the two kernel matrices
  140. arith: arithmetic mean of the corresponding elements in the two kernel matrices
  141. geo: geometric mean
  142. '''
  143. if self._pcTrainKernel in None:
  144. raise Exception("Kernel is not precomputed yet. Run precomputeTrainKernel() first.")
  145. self._pcTrainKernel = self._combineKernels(self._pcTrainKernel, paKernelMatrix, pCombMethod)
  146. def combinePrecomputedTestKernel(self, paKernelMatrix, pCombMethod):
  147. '''
  148. Combines the anygram kernel computed on the test set here with any given kernel matrix using the specified method
  149. The given kernel matrix should match the computed any-gram kernel matrix in shape.
  150. The combination methods supported here are:
  151. + or add: add corresponding elements in the two kernel matrices
  152. * or multiply: multiply corresponding elements in the two kernel matrices
  153. arith: arithmetic mean of the corresponding elements in the two kernel matrices
  154. geo: geometric mean
  155. '''
  156. if self._pcTestKernel in None:
  157. raise Exception("Kernel is not precomputed yet. Run precomputeTrainKernel() first.")
  158. self._pcTestKernel = self._combineKernels(self._pcTestKernel, paKernelMatrix, pCombMethod)
  159. def _combineKernels(self, paKernelMatrix1, paKernelMatrix2, pCombMethod):
  160. '''
  161. Combines and returns two given kernel matrices with the givem method
  162. The given kernel matrices should have the same shapes.
  163. The combination methods supported here are:
  164. + or add: add corresponding elements in the two kernel matrices
  165. * or multiply: multiply corresponding elements in the two kernel matrices
  166. arith: arithmetic mean of the corresponding elements in the two kernel matrices
  167. geo: geometric mean
  168. '''
  169. if paKernelMatrix1.shape != paKernelMatrix2.shape:
  170. raise Exception("The shape of the given kernel matrix is not valid: " % paKernelMatrix1.shape)
  171. if pCombMethod.lower() in ['+', "add"]:
  172. return np.add(paKernelMatrix1, paKernelMatrix1)
  173. elif pCombMethod.lower() in ['*', "multiply"]:
  174. return np.multiply(paKernelMatrix1, paKernelMatrix1)
  175. elif pCombMethod.lower().startswith("arith"):
  176. return np.add(paKernelMatrix1, paKernelMatrix2) / 2
  177. elif pCombMethod.lower().startswith("geo"):
  178. return np.sqrt(np.multiply(paKernelMatrix1, paKernelMatrix2))
  179. def train(self, pflgUsePrecompKernel = False, pMCMethod = None, C = 1, class_weight = None):
  180. '''
  181. Trains and returns anygram model
  182. If pflgUsePrecompKernel is set to true, SVC will use precomputed kernel. This can save time when the data or
  183. kernel computation parameters remain the same in repeated trainings (e.g. in tunning).
  184. pMCMethod is decision_function_shape parameter of the scikit.svm.SVC and specifies the multiclass classification
  185. method. The othe parameters are those of scikit.svm.SVC.
  186. '''
  187. if self.trainXs is None or len(self.trainXs) == 0:
  188. raise Exception("Training dataset is empty!")
  189. if pflgUsePrecompKernel:
  190. vKernel = "precomputed"
  191. X = self._getPrecomputedTrainKernel()
  192. else:
  193. vKernel = self.ag
  194. X = self.trainXs
  195. if pMCMethod is None:
  196. vSVC = svm.SVC(kernel = vKernel, C = C, class_weight = class_weight)
  197. elif pMCMethod.lower() == "ovo":
  198. vSVC = svm.SVC(kernel = vKernel, decision_function_shape = "ovo", C = C, class_weight = class_weight)
  199. elif pMCMethod.lower() in ["ova", "ovr"]:
  200. vSVC = OneVsRestClassifier(svm.SVC(kernel = vKernel, decision_function_shape = pMCMethod, C = C, class_weight = class_weight))
  201. else:
  202. raise Exception("Unknown multiclass classification method: %s", pMCMethod)
  203. # training
  204. self._model = vSVC.fit(X, self.trainYs)
  205. @property
  206. def model(self):
  207. '''
  208. Trained scikit SVC model
  209. '''
  210. return self._model
  211. def saveModel(self, pFilename):
  212. '''
  213. Saves the SVC model by pickling it to a given file
  214. '''
  215. ## ToDo: when saving the model, the vocabulary and the embeddings (for WESS) must also be saved
  216. pickle.dump(self._model, open(pFilename, 'w'))
  217. def loadModel(self, pModelPickle):
  218. '''
  219. Loads the pickled SVC model
  220. '''
  221. ## ToDo: model should have been saved with a vocabulary and embedding which must also be loaded here
  222. self._model = pickle.load(open(pModelPickle))
  223. def test(self, pTestXs = None, pTestYs = None):
  224. '''
  225. Tests the given models on the loaded test set
  226. '''
  227. if pTestXs is not None and pTestYs is not None:
  228. self.loadTestSet(pXs = pTestXs, pYs = pTestYs)
  229. if isinstance(self._model, OneVsRestClassifier):
  230. vKernel = self._model.estimators_[0].kernel
  231. elif isinstance(self._model, svm.SVC):
  232. vKernel = self._model.kernel
  233. # prediction
  234. if vKernel == "precomputed":
  235. vaPreds = self._model.predict(self._getPrecomputedTestKernel())
  236. else:
  237. vaPreds = self._model.predict(self.testXs)
  238. # scoring
  239. vScore = self._score(vaPreds, self.testYs)
  240. return vaPreds, vScore
  241. def predict(self, pXs, pAux = None):
  242. '''
  243. Predicts the labels of the given data
  244. '''
  245. vaXs = self.ag.formatData(pXs)
  246. vaAuxs = self.formatAux(pAux)
  247. if isinstance(self._model, OneVsRestClassifier):
  248. vKernel = self._model.estimators_[0].kernel
  249. elif isinstance(self._model, svm.SVC):
  250. vKernel = self._model.kernel
  251. if vKernel == "precomputed":
  252. return self._model.predict(self._preComputeKernel(vaXs, self.trainXs, vaAuxs, self.trainAuxs))
  253. else:
  254. return self._model.predict(vaXs)
  255. def _score(self, plPreds, plGolds):
  256. '''
  257. Scores the given predictions
  258. '''
  259. vCorrect = 0
  260. for p, g in zip(plPreds, plGolds):
  261. if p == g:
  262. vCorrect += 1
  263. return vCorrect * 1.0 / len(plPreds)