Elasticsearch 工具类封装（基于 ElasticsearchTemplate）

1. 抽象类定义

  1 public abstract class SearchQueryEngine<T> {
  2 
  3     @Autowired
  4     protected ElasticsearchTemplate elasticsearchTemplate;
  5 
  6     public abstract int saveOrUpdate(List<T> list);
  7 
  8     public abstract <R> List<R> aggregation(T query, Class<R> clazz);
  9 
 10     public abstract <R> Page<R> scroll(T query, Class<R> clazz, Pageable pageable, ScrollId scrollId);
 11 
 12     public abstract <R> List<R> find(T query, Class<R> clazz, int size);
 13 
 14     public abstract <R> Page<R> find(T query, Class<R> clazz, Pageable pageable);
 15 
 16     public abstract <R> R sum(T query, Class<R> clazz);
 17 
 18     protected Document getDocument(T t) {
 19         Document annotation = t.getClass().getAnnotation(Document.class);
 20         if (annotation == null) {
 21             throw new SearchQueryBuildException("Can't find annotation @Document on " + t.getClass().getName());
 22         }
 23         return annotation;
 24     }
 25 
 26     /**
 27      * 获取字段名,若设置column则返回该值
 28      *
 29      * @param field
 30      * @param column
 31      * @return
 32      */
 33     protected String getFieldName(Field field, String column) {
 34         return StringUtils.isNotBlank(column) ? column : field.getName();
 35     }
 36 
 37     /**
 38      * 设置属性值
 39      *
 40      * @param field
 41      * @param obj
 42      * @param value
 43      */
 44     protected void setFieldValue(Field field, Object obj, Object value) {
 45         boolean isAccessible = field.isAccessible();
 46         field.setAccessible(true);
 47         try {
 48             switch (field.getType().getSimpleName()) {
 49                 case "BigDecimal":
 50                     field.set(obj, new BigDecimal(value.toString()).setScale(5, BigDecimal.ROUND_HALF_UP));
 51                     break;
 52                 case "Long":
 53                     field.set(obj, new Long(value.toString()));
 54                     break;
 55                 case "Integer":
 56                     field.set(obj, new Integer(value.toString()));
 57                     break;
 58                 case "Date":
 59                     field.set(obj, new Date(Long.valueOf(value.toString())));
 60                     break;
 61                 default:
 62                     field.set(obj, value);
 63             }
 64         } catch (IllegalAccessException e) {
 65             throw new SearchQueryBuildException(e);
 66         } finally {
 67             field.setAccessible(isAccessible);
 68         }
 69     }
 70 
 71     /**
 72      * 获取字段值
 73      *
 74      * @param field
 75      * @param obj
 76      * @return
 77      */
 78     protected Object getFieldValue(Field field, Object obj) {
 79         boolean isAccessible = field.isAccessible();
 80         field.setAccessible(true);
 81         try {
 82             return field.get(obj);
 83         } catch (IllegalAccessException e) {
 84             throw new SearchQueryBuildException(e);
 85         } finally {
 86             field.setAccessible(isAccessible);
 87         }
 88     }
 89 
 90     /**
 91      * 转换为es识别的value值
 92      *
 93      * @param value
 94      * @return
 95      */
 96     protected Object formatValue(Object value) {
 97         if (value instanceof Date) {
 98             return ((Date) value).getTime();
 99         } else {
100             return value;
101         }
102     }
103 
104     /**
105      * 获取索引分区数
106      *
107      * @param t
108      * @return
109      */
110     protected int getNumberOfShards(T t) {
111         return Integer.parseInt(elasticsearchTemplate.getSetting(getDocument(t).index()).get(IndexMetaData.SETTING_NUMBER_OF_SHARDS).toString());
112     }
113 }

 

2. 实现类

  1 @Component
  2 @ComponentScan
  3 public class SimpleSearchQueryEngine<T> extends SearchQueryEngine<T> {
  4 
  5     private int numberOfRowsPerScan = 10;
  6 
  7     @Override
  8     public int saveOrUpdate(List<T> list) {
  9         if (CollectionUtils.isEmpty(list)) {
 10             return 0;
 11         }
 12 
 13         T base = list.get(0);
 14         Field id = null;
 15         for (Field field : base.getClass().getDeclaredFields()) {
 16             BusinessID businessID = field.getAnnotation(BusinessID.class);
 17             if (businessID != null) {
 18                 id = field;
 19                 break;
 20             }
 21         }
 22         if (id == null) {
 23             throw new SearchQueryBuildException("Can't find @BusinessID on " + base.getClass().getName());
 24         }
 25 
 26         Document document = getDocument(base);
 27         List<UpdateQuery> bulkIndex = new ArrayList<>();
 28         for (T t : list) {
 29             UpdateQuery updateQuery = new UpdateQuery();
 30             updateQuery.setIndexName(document.index());
 31             updateQuery.setType(document.type());
 32             updateQuery.setId(getFieldValue(id, t).toString());
 33             updateQuery.setUpdateRequest(new UpdateRequest(updateQuery.getIndexName(), updateQuery.getType(), updateQuery.getId()).doc(JSONObject.toJSONString(t, SerializerFeature.WriteMapNullValue)));
 34             updateQuery.setDoUpsert(true);
 35             updateQuery.setClazz(t.getClass());
 36             bulkIndex.add(updateQuery);
 37         }
 38         elasticsearchTemplate.bulkUpdate(bulkIndex);
 39         return list.size();
 40     }
 41 
 42     @Override
 43     public <R> List<R> aggregation(T query, Class<R> clazz) {
 44         NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query);
 45         nativeSearchQueryBuilder.addAggregation(buildGroupBy(query));
 46         Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor());
 47         try {
 48             return transformList(null, aggregations, clazz.newInstance(), new ArrayList());
 49         } catch (Exception e) {
 50             throw new SearchResultBuildException(e);
 51         }
 52     }
 53 
 54     /**
 55      * 将Aggregations转为List
 56      *
 57      * @param terms
 58      * @param aggregations
 59      * @param baseObj
 60      * @param resultList
 61      * @param <R>
 62      * @return
 63      * @throws NoSuchFieldException
 64      * @throws IllegalAccessException
 65      * @throws InstantiationException
 66      */
 67     private <R> List<R> transformList(Aggregation terms, Aggregations aggregations, R baseObj, List<R> resultList) throws NoSuchFieldException, IllegalAccessException, InstantiationException {
 68         for (String column : aggregations.asMap().keySet()) {
 69             Aggregation childAggregation = aggregations.get(column);
 70             if (childAggregation instanceof InternalSum) {
 71                 // 使用@Sum
 72                 if (!(terms instanceof InternalSum)) {
 73                     R targetObj = (R) baseObj.getClass().newInstance();
 74                     BeanUtils.copyProperties(baseObj, targetObj);
 75                     resultList.add(targetObj);
 76                 }
 77                 setFieldValue(baseObj.getClass().getDeclaredField(column), resultList.get(resultList.size() - 1), ((InternalSum) childAggregation).getValue());
 78                 terms = childAggregation;
 79             } else {
 80                 Terms childTerms = (Terms) childAggregation;
 81                 for (Terms.Bucket bucket : childTerms.getBuckets()) {
 82                     if (CollectionUtils.isEmpty(bucket.getAggregations().asList())) {
 83                         // 未使用@Sum
 84                         R targetObj = (R) baseObj.getClass().newInstance();
 85                         BeanUtils.copyProperties(baseObj, targetObj);
 86                         setFieldValue(targetObj.getClass().getDeclaredField(column), targetObj, bucket.getKey());
 87                         resultList.add(targetObj);
 88                     } else {
 89                         setFieldValue(baseObj.getClass().getDeclaredField(column), baseObj, bucket.getKey());
 90                         transformList(childTerms, bucket.getAggregations(), baseObj, resultList);
 91                     }
 92                 }
 93             }
 94         }
 95         return resultList;
 96     }
 97 
 98     @Override
 99     public <R> Page<R> scroll(T query, Class<R> clazz, Pageable pageable, ScrollId scrollId) {
100         if (pageable.getPageSize() % numberOfRowsPerScan > 0) {
101             throw new SearchQueryBuildException("Page size must be an integral multiple of " + numberOfRowsPerScan);
102         }
103         SearchQuery searchQuery = buildNativeSearchQueryBuilder(query).withPageable(new PageRequest(pageable.getPageNumber(), numberOfRowsPerScan / getNumberOfShards(query), pageable.getSort())).build();
104         if (StringUtils.isEmpty(scrollId.getValue())) {
105             scrollId.setValue(elasticsearchTemplate.scan(searchQuery, 10000l, false));
106         }
107         Page<R> page = elasticsearchTemplate.scroll(scrollId.getValue(), 10000l, clazz);
108         if (page == null || page.getContent().size() == 0) {
109             elasticsearchTemplate.clearScroll(scrollId.getValue());
110         }
111         return page;
112     }
113 
114     @Override
115     public <R> List<R> find(T query, Class<R> clazz, int size) {
116         // Caused by: QueryPhaseExecutionException[Result window is too large, from + size must be less than or equal to: [10000] but was [2147483647].
117         // See the scroll api for a more efficient way to request large data sets. This limit can be set by changing the [index.max_result_window] index level parameter.]
118         if (size % numberOfRowsPerScan > 0) {
119             throw new SearchQueryBuildException("Parameter 'size' must be an integral multiple of " + numberOfRowsPerScan);
120         }
121         int pageNum = 0;
122         List<R> result = new ArrayList<>();
123         ScrollId scrollId = new ScrollId();
124         while (true) {
125             Page<R> page = scroll(query, clazz, new PageRequest(pageNum, numberOfRowsPerScan), scrollId);
126             if (page != null && page.getContent().size() > 0) {
127                 result.addAll(page.getContent());
128             } else {
129                 break;
130             }
131             if (result.size() >= size) {
132                 break;
133             } else {
134                 pageNum++;
135             }
136         }
137         elasticsearchTemplate.clearScroll(scrollId.getValue());
138         return result;
139     }
140 
141     @Override
142     public <R> Page<R> find(T query, Class<R> clazz, Pageable pageable) {
143         NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query).withPageable(pageable);
144         return elasticsearchTemplate.queryForPage(nativeSearchQueryBuilder.build(), clazz);
145     }
146 
147     @Override
148     public <R> R sum(T query, Class<R> clazz) {
149         NativeSearchQueryBuilder nativeSearchQueryBuilder = buildNativeSearchQueryBuilder(query);
150         for (SumBuilder sumBuilder : getSumBuilderList(query)) {
151             nativeSearchQueryBuilder.addAggregation(sumBuilder);
152         }
153         Aggregations aggregations = elasticsearchTemplate.query(nativeSearchQueryBuilder.build(), new AggregationResultsExtractor());
154         try {
155             return transformSumResult(aggregations, clazz);
156         } catch (Exception e) {
157             throw new SearchResultBuildException(e);
158         }
159     }
160 
161     private <R> R transformSumResult(Aggregations aggregations, Class<R> clazz) throws IllegalAccessException, InstantiationException, NoSuchFieldException {
162         R targetObj = clazz.newInstance();
163         for (Aggregation sum : aggregations.asList()) {
164             if (sum instanceof InternalSum) {
165                 setFieldValue(targetObj.getClass().getDeclaredField(sum.getName()), targetObj, ((InternalSum) sum).getValue());
166             }
167         }
168         return targetObj;
169     }
170 
171     private NativeSearchQueryBuilder buildNativeSearchQueryBuilder(T query) {
172         Document document = getDocument(query);
173         NativeSearchQueryBuilder nativeSearchQueryBuilder = new NativeSearchQueryBuilder()
174                 .withIndices(document.index())
175                 .withTypes(document.type());
176 
177         QueryBuilder whereBuilder = buildBoolQuery(query);
178         if (whereBuilder != null) {
179             nativeSearchQueryBuilder.withQuery(whereBuilder);
180         }
181 
182         return nativeSearchQueryBuilder;
183     }
184 
185     /**
186      * 布尔查询构建
187      *
188      * @param query
189      * @return
190      */
191     private BoolQueryBuilder buildBoolQuery(T query) {
192         BoolQueryBuilder boolQueryBuilder = boolQuery();
193         buildMatchQuery(boolQueryBuilder, query);
194         buildRangeQuery(boolQueryBuilder, query);
195         BoolQueryBuilder queryBuilder = boolQuery().must(boolQueryBuilder);
196         return queryBuilder;
197     }
198 
199     /**
200      * and or 查询构建
201      *
202      * @param boolQueryBuilder
203      * @param query
204      */
205     private void buildMatchQuery(BoolQueryBuilder boolQueryBuilder, T query) {
206         Class clazz = query.getClass();
207         for (Field field : clazz.getDeclaredFields()) {
208             MatchQuery annotation = field.getAnnotation(MatchQuery.class);
209             Object value = getFieldValue(field, query);
210             if (annotation == null || value == null) {
211                 continue;
212             }
213             if (Container.must.equals(annotation.container())) {
214                 boolQueryBuilder.must(matchQuery(getFieldName(field, annotation.column()), formatValue(value)));
215             } else if (should.equals(annotation.container())) {
216                 if (value instanceof Collection) {
217                     BoolQueryBuilder shouldQueryBuilder = boolQuery();
218                     Collection tmp = (Collection) value;
219                     for (Object obj : tmp) {
220                         shouldQueryBuilder.should(matchQuery(getFieldName(field, annotation.column()), formatValue(obj)));
221                     }
222                     boolQueryBuilder.must(shouldQueryBuilder);
223                 } else {
224                     boolQueryBuilder.must(boolQuery().should(matchQuery(getFieldName(field, annotation.column()), formatValue(value))));
225                 }
226             }
227         }
228     }
229 
230     /**
231      * 范围查询构建
232      *
233      * @param boolQueryBuilder
234      * @param query
235      */
236     private void buildRangeQuery(BoolQueryBuilder boolQueryBuilder, T query) {
237         Class clazz = query.getClass();
238         for (Field field : clazz.getDeclaredFields()) {
239             RangeQuery annotation = field.getAnnotation(RangeQuery.class);
240             Object value = getFieldValue(field, query);
241             if (annotation == null || value == null) {
242                 continue;
243             }
244             if (Operator.gt.equals(annotation.operator())) {
245                 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).gt(formatValue(value)));
246             } else if (Operator.gte.equals(annotation.operator())) {
247                 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).gte(formatValue(value)));
248             } else if (Operator.lt.equals(annotation.operator())) {
249                 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).lt(formatValue(value)));
250             } else if (Operator.lte.equals(annotation.operator())) {
251                 boolQueryBuilder.must(rangeQuery(getFieldName(field, annotation.column())).lte(formatValue(value)));
252             }
253         }
254     }
255 
256     /**
257      * Sum构建
258      *
259      * @param query
260      * @return
261      */
262     private List<SumBuilder> getSumBuilderList(T query) {
263         List<SumBuilder> list = new ArrayList<>();
264         Class clazz = query.getClass();
265         for (Field field : clazz.getDeclaredFields()) {
266             Sum annotation = field.getAnnotation(Sum.class);
267             if (annotation == null) {
268                 continue;
269             }
270             list.add(AggregationBuilders.sum(field.getName()).field(field.getName()));
271         }
272         if (CollectionUtils.isEmpty(list)) {
273             throw new SearchQueryBuildException("Can't find @Sum on " + clazz.getName());
274         }
275         return list;
276     }
277 
278 
279     /**
280      * GroupBy构建
281      *
282      * @param query
283      * @return
284      */
285     private TermsBuilder buildGroupBy(T query) {
286         List<Field> sumList = new ArrayList<>();
287         Object groupByCollection = null;
288         Class clazz = query.getClass();
289         for (Field field : clazz.getDeclaredFields()) {
290             Sum sumAnnotation = field.getAnnotation(Sum.class);
291             if (sumAnnotation != null) {
292                 sumList.add(field);
293             }
294             GroupBy groupByannotation = field.getAnnotation(GroupBy.class);
295             Object value = getFieldValue(field, query);
296             if (groupByannotation == null || value == null) {
297                 continue;
298             } else if (!(value instanceof Collection)) {
299                 throw new SearchQueryBuildException("GroupBy filed must be collection");
300             } else if (CollectionUtils.isEmpty((Collection<String>) value)) {
301                 continue;
302             } else if (groupByCollection != null) {
303                 throw new SearchQueryBuildException("Only one @GroupBy is allowed");
304             } else {
305                 groupByCollection = value;
306             }
307         }
308         Iterator<String> iterator = ((Collection<String>) groupByCollection).iterator();
309         TermsBuilder termsBuilder = recursiveAddAggregation(iterator, sumList);
310         return termsBuilder;
311     }
312 
313     /**
314      * 添加Aggregation
315      *
316      * @param iterator
317      * @return
318      */
319     private TermsBuilder recursiveAddAggregation(Iterator<String> iterator, List<Field> sumList) {
320         String groupBy = iterator.next();
321         TermsBuilder termsBuilder = AggregationBuilders.terms(groupBy).field(groupBy).size(0);
322         if (iterator.hasNext()) {
323             termsBuilder.subAggregation(recursiveAddAggregation(iterator, sumList));
324         } else {
325             for (Field field : sumList) {
326                 termsBuilder.subAggregation(AggregationBuilders.sum(field.getName()).field(field.getName()));
327             }
328             sumList.clear();
329         }
330         return termsBuilder.order(Terms.Order.term(true));
331     }

3.存储scrollId值对象

import lombok.Data;

/**
 * Mutable holder for an Elasticsearch scroll id, passed between successive scroll calls.
 *
 * <p>{@code SimpleSearchQueryEngine#scroll} opens a new scan and stores its id here when the value
 * is empty, and reuses the stored id otherwise. Lombok's {@code @Data} generates the getter,
 * setter, equals/hashCode and toString.
 */
@Data
public class ScrollId {

    /** Scroll id returned by Elasticsearch; {@code null} until the first scan is opened. */
    private String value;

}

4.用于判断查询操作的枚举类

/**
 * Comparison used by a {@code @RangeQuery} field: the indexed value is compared against the
 * field's value with this operator.
 */
public enum Operator {
    /** Strictly greater than. */
    gt,
    /** Greater than or equal to. */
    gte,
    /** Strictly less than. */
    lt,
    /** Less than or equal to. */
    lte;
}
/**
 * How a {@code @MatchQuery} field is placed into the generated bool query.
 */
public enum Container {
    /** The value is a required match clause. */
    must,
    /** The value (or any element of a collection value) is one of several alternative matches. */
    should;
}

 

posted @ 2019-01-17 09:56  肖哥哥  阅读(10331)  评论(0编辑  收藏  举报
生命不息  奋斗不止  每天进步一点点