BM25 検索関連度スコアリングアルゴリズム

7193 ワード

package com.btg.core.util.bm25;

import org.wltea.analyzer.core.IKSegmenter;
import org.wltea.analyzer.core.Lexeme;

import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * https://www.jianshu.com/p/1e498888f505
 * https://www.cnblogs.com/jiangxinyang/p/10516302.html
 * https://www.zybuluo.com/evilking/note/902621
 * https://github.com/hankcs/HanLP/blob/master/src/main/java/com/hankcs/hanlp/summary/BM25.java
 * https://github.com/haifengl/smile/blob/master/nlp/src/main/java/smile/nlp/relevance/BM25.java
 * https://github.com/jllan/jannlp
 *
 * BM25      https://www.cnblogs.com/NaughtyBaby/p/9774836.html
 *
 *          
 *            
 */
public class BM25Test {

    //     
    private static final double k1 = 1.5;
    private static final double b = 0.75;

    /**
     *     
     * @return
     */
    private static List sentences() {
        List sentences = new ArrayList<>();
        sentences.add("Elasticsearch is a highly scalable open-source full-text search and analytics engine");
        sentences.add("It allows you to store, search, and analyze big volumes of data quickly and in near real time");
        sentences.add("is generally used as the underlying engine/technology that powers applications that have complex search features and requirements");
        sentences.add("You run an online web store where you allow your customers to search for products that you sell");
        sentences.add("You want to collect log or transaction data and you want to analyze and mine this data to look for trends, statistics, summarizations, or anomalies");
        return sentences;
    }

    /**
     *     
     * @return
     */
    private static String query() {
        return "want";
    }

    public static void main(String[] args) {
        List sentences = sentences();
        //        
        final int N = sentences.size();
        //          
        final double avgdl = avgdl(sentences);
        System.out.println("avgdl = " + avgdl);

        String query = query();
        System.out.println("     = " + query);

        List queryWords = ikanalyzer(query);
        System.out.println("         = " + queryWords);

        System.out.println("--------------------------------------------------------");

        System.out.println("     = " + sentences);
        List> allWords = new ArrayList<>();
        //       
        for(int i = 0, len = sentences.size(); i < len; i++) {
            allWords.add(i, ikanalyzer(sentences.get(i)));
        }
        System.out.println("         = " + allWords);
        System.out.println("--------------------------------------------------------");

        List> fs = f(allWords);
        System.out.println("                = " + fs);
        System.out.println("--------------------------------------------------------");

        Map nqis = nqi(allWords, queryWords);
        System.out.println("           = " + nqis);
        System.out.println("--------------------------------------------------------");

        Map idfs = idf(N, queryWords, nqis);
        System.out.println("                 = " + idfs);
        System.out.println("--------------------------------------------------------");

        for(int i = 0, len = sentences.size(); i < len; i++) {
            String sentence = sentences.get(i);
            double res = 0;
            for(String qw : queryWords) {
//                if(!sentence.contains(qw)) {
//                    continue;
//                }
                // fi qi   d      
                Double wi = idfs.get(qw);
                int fi = fs.get(i).getOrDefault(qw, 0);
//                System.out.println(sentence + " qw = " + qw + " fi = " + fi);
                double R = fi * (k1 + 1) / (fi + K(sentence, avgdl));
                res += wi * R;
            }
            System.out.println("   = " + sentence + "     = " + res);
        }
    }



    /**
     *   
     * @return
     */
    private static List ikanalyzer(String line) {
        StringReader re = new StringReader(line);
        IKSegmenter ik = new IKSegmenter(re,true);
        Lexeme lex = null;
        List words = new ArrayList<>();
        try {
            while((lex = ik.next()) != null){
                String text = lex.getLexemeText();
                words.add(text);
            }
            return words;
        }catch (Exception e) {
            e.printStackTrace();
        }
        return words;
    }

    /**
     *   qi         
     * @return
     */
    private static Map nqi(List> allWords, List queryWords) {
        Map nqis = new TreeMap<>();
        for(String qw : queryWords) {
            for(List aws : allWords) {
                if(aws.contains(qw)) {
                    nqis.put(qw, nqis.getOrDefault(qw, 0) + 1);
                }
            }
        }
        return nqis;
    }

    /**
     *   w(i),q(i)  
     *         
     * @return
     */
    private static Map idf(int N, List queryWords, Map nqis) {
        Map idfs = new HashMap<>();
        for(String qw : queryWords) {
//            System.out.println("   " + qw + "       = " + nqis.getOrDefault(qw, 0));
            double temp = (N - nqis.getOrDefault(qw, 0) + 0.5) / (nqis.getOrDefault(qw, 0) + 0.5);
            double idf = Math.log10(1 + temp);
//            System.out.println("    " + qw + ", idf = " + idf);
            idfs.put(qw, idf);
        }
        return idfs;
    }

    /**
     *                
     * @param allWords
     * @return
     */
    private static List> f(List> allWords) {
        List> fs = new ArrayList<>();
        for(List aw : allWords) {
            Map map = new TreeMap<>();
            for(String w : aw) {
                map.put(w, map.getOrDefault(w, 0) + 1);
            }
            fs.add(map);
        }
        return fs;
    }

    /**
     *    D          
     * @param sentences
     * @return
     */
    private static double avgdl(List sentences) {
        double totalLen = 0.0;
        for(String sentence : sentences) {
            totalLen += sentence.length();
        }
        return totalLen / sentences.size();
    }

    private static double K(String sentence, double avgdl) {
        int dl = sentence.length();
        return k1 * (1 - b + b * dl / avgdl);
    }

}