// AlgService.java 6.75 KB
package com.viontech.match.service;

import com.viontech.match.entity.Person;
import com.viontech.match.entity.MatchResult;
import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.delete.DeleteRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.indices.CreateIndexRequest;
import org.elasticsearch.client.indices.CreateIndexResponse;
import org.elasticsearch.client.indices.GetIndexRequest;
import org.elasticsearch.client.indices.GetIndexResponse;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.springframework.stereotype.Service;

import javax.annotation.Resource;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * .
 *
 * @author 谢明辉
 * @date 2020/11/20
 */

@Service
public class AlgService {

    @Resource
    private RestHighLevelClient client;

    /**
     * 建立特征库索引
     *
     * @param poolName 特征库名称,作为index
     */
    public Object createPool(String poolName) throws IOException {
        CreateIndexRequest test2 = new CreateIndexRequest(poolName);
        XContentBuilder builder = XContentFactory.jsonBuilder();
        builder.startObject();
        {
            builder.startObject("properties");
            {
                builder.startObject("data");
                {
                    builder.field("type", "dense_vector");
                    builder.field("dims", 512);
                }
                builder.endObject();
                builder.startObject("personId");
                {
                    builder.field("type", "keyword");
                }
                builder.endObject();
            }
            builder.endObject();

        }
        builder.endObject();
        test2.mapping(builder);
        CreateIndexResponse createIndexResponse = client.indices().create(test2, RequestOptions.DEFAULT);
        if (createIndexResponse.isAcknowledged()) {
            return "success";
        } else {
            return "failed";
        }
    }

    /**
     * 删除特征库
     *
     * @param poolName 特征库名称,index
     */
    public Object deletePool(String poolName) throws IOException {

        boolean exists = client.indices().exists(new GetIndexRequest(poolName), RequestOptions.DEFAULT);
        if (!exists) {
            return "特征池不存在";
        }
        DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(poolName);
        AcknowledgedResponse delete = client.indices().delete(deleteIndexRequest, RequestOptions.DEFAULT);
        if (delete.isAcknowledged()) {
            return "success";
        } else {
            return "failed";
        }

    }

    /**
     * 查询特征库列表
     */
    public String[] queryPoolList() throws IOException {
        GetIndexResponse response = client.indices().get(new GetIndexRequest("*"), RequestOptions.DEFAULT);
        return response.getIndices();
    }

    /**
     * 人员比对
     *
     * @param person 需要比对的人员
     * @param poolId 用来比对的特征库
     */
    public List<MatchResult> matchPerson(Person person, String poolId) throws Exception {
        Double[] feature = person.getFeature();
        if (feature.length < 512) {
            throw new Exception("特征维数为:" + feature.length + ",小于512维");
        }
        SearchRequest searchRequest = new SearchRequest(poolId);
        SearchSourceBuilder builder = new SearchSourceBuilder();
        builder.fetchSource("personId", null);
        builder.size(40);
        Map<String, Object> params = new HashMap<>(1);
        params.put("data", feature);
        Script script = new Script(
                ScriptType.INLINE, Script.DEFAULT_SCRIPT_LANG,
                "(cosineSimilarity(params.data, 'data') + 1) / 2 * 100", params);

        ScriptScoreQueryBuilder queryBuilder = QueryBuilders.scriptScoreQuery(QueryBuilders.matchAllQuery(), script);
        builder.query(queryBuilder);
        builder.fetchSource("personId", null);

        SearchRequest source = searchRequest.source(builder);
        SearchResponse search = client.search(source, RequestOptions.DEFAULT);
        SearchHits hits = search.getHits();
        SearchHit[] hits1 = hits.getHits();
        ArrayList<MatchResult> matchResults = new ArrayList<>();
        for (SearchHit item : hits1) {
            String personId = item.getId();
            float score = item.getScore();
            String index = item.getIndex();
            MatchResult matchResult = new MatchResult().setPersonId(personId).setScore(score).setPoolId(index);
            matchResults.add(matchResult);
        }
        return matchResults;
    }

    /**
     * 添加人员
     *
     * @param person 需要添加的人员
     * @param poolId 特征库名称
     */
    public Object addPerson(Person person, String poolId) throws IOException {
        IndexRequest indexRequest = new IndexRequest(poolId)
                .id(person.getId())
                .source(XContentType.JSON, "personId", person.getPersonId(), "data", person.getFeature());
        return client.index(indexRequest, RequestOptions.DEFAULT);
    }

    public Object addAllPerson(List<Person> persons, String poolId) throws Exception {
        BulkRequest bulkRequest = new BulkRequest();
        for (Person person : persons) {
            bulkRequest.add(new IndexRequest(poolId)
                    .id(person.getId())
                    .source(XContentType.JSON, "personId", person.getPersonId(), "data", person.getFeature()));
        }
        return client.bulk(bulkRequest, RequestOptions.DEFAULT);
    }

    /**
     * 删除人员
     *
     * @param personId 需要删除的人员的id
     * @param poolId   特征池名称
     */
    public Object delPerson(String personId, String poolId) throws Exception {
        DeleteRequest deleteRequest = new DeleteRequest(poolId, personId);
        return client.delete(deleteRequest, RequestOptions.DEFAULT);
    }


}