#!/usr/bin/env python3
# retrieval.py
import argparse
import os

import pandas as pd
import pyterrier as pt
  5. id_template = 'spotify:episode:{}_{}'
  6. def parse_arguments():
  7. parser = argparse.ArgumentParser()
  8. parser.add_argument('data_properties')
  9. parser.add_argument('topics')
  10. parser.add_argument('run_id', type=str)
  11. parser.add_argument('--format',
  12. choices=['trec', 'submission'], default='trec')
  13. return parser.parse_args()
  14. def write_submission(df):
  15. print('RUNID QUERYID RANK SCORE EPISODEID OFFSET')
  16. for t in df.itertuples():
  17. qid, rank, docno, score = t[1], t[2], t[3], t[4]
  18. episode, timestamp = docno.split('_')
  19. start_time = str(float(timestamp.split('-')[0]))
  20. episode_id = id_template.format(episode, start_time)
  21. print('{} {} {} {} {} {}'.format(
  22. args.run_id, qid, rank, score, episode_id, start_time))
  23. def write_trec(df):
  24. print('query-id Q0 document-id rank score STANDARD')
  25. for t in df.itertuples():
  26. qid, rank, docno, score = t[1], t[2], t[3], t[4]
  27. episode, timestamp = docno.split('_')
  28. start_time = str(float(timestamp.split('-')[0]))
  29. episode_id = id_template.format(episode, start_time)
  30. print('{} {} {} {} {} {}'.format(
  31. qid, '0', episode_id, rank, score, args.run_id))
  32. if __name__=="__main__":
  33. args = parse_arguments()
  34. pt.init()
  35. index_dir = './' + args.data_properties
  36. index_ref = pt.IndexRef.of(index_dir)
  37. index = pt.IndexFactory.of(index_ref)
  38. topics = pt.Utils.parse_trecxml_topics_file(args.topics)
  39. retr = pt.BatchRetrieve(index)
  40. res = retr.transform(topics)
  41. df = pd.DataFrame(res, columns=['qid', 'rank', 'docno', 'score'])
  42. if args.format == 'trec':
  43. write_trec(df)
  44. elif args.format == 'submission':
  45. write_submission(df)