AggregationService.java

package com.wilzwert.myjobs.infrastructure.persistence.mongo.service;

import com.wilzwert.myjobs.core.domain.shared.specification.DomainSpecification;
import org.bson.Document;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.*;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;


/**
 * Builds MongoDB {@link Aggregation} pipelines from {@link DomainSpecification}s (through a
 * {@link DomainSpecificationConverter}) and executes them with a {@link MongoTemplate}.
 */
@Service
public class AggregationService {
    private final MongoTemplate mongoTemplate;

    private final DomainSpecificationConverter converter;

    public AggregationService(MongoTemplate mongoTemplate, DomainSpecificationConverter converter) {
        this.mongoTemplate = mongoTemplate;
        this.converter = converter;
    }

    /**
     * Wraps the given operations in a new {@link Aggregation}.
     *
     * @param operationList the pipeline stages, in order
     * @return a new aggregation built from a copy of the stages
     * @throws IllegalArgumentException if the list is null or empty
     */
    private Aggregation createAggregation(List<AggregationOperation> operationList) {
        // Treat a null list like an empty one so callers get a meaningful error instead of an NPE.
        if (operationList == null || operationList.isEmpty()) {
            throw new IllegalArgumentException("Cannot create Aggregation : Operation list is empty");
        }

        return Aggregation.newAggregation(operationList);
    }

    /** Delegates the domain-to-MongoDB translation to the injected converter. */
    private List<AggregationOperation> domainToOperations(DomainSpecification specification) {
        return converter.convert(specification);
    }

    /**
     * Creates an aggregation from a domain specification.
     *
     * @param specification the domain specification to translate
     * @return the corresponding aggregation
     * @throws IllegalArgumentException if the specification converts to no operations
     */
    public Aggregation createAggregation(DomainSpecification specification) {
        return createAggregation(domainToOperations(specification));
    }

    /**
     * Creates an aggregation from a domain specification and appends {@code $skip} and
     * {@code $limit} stages for the requested page.
     *
     * @param specification the domain specification to translate
     * @param page zero-based page index
     * @param size number of elements per page
     * @return the paginated aggregation
     * @throws IllegalArgumentException if {@code page} is negative, {@code size} is not positive,
     *     or the specification converts to no operations
     */
    public Aggregation createAggregationPaginated(DomainSpecification specification, int page, int size) {
        // Fail fast: a negative skip or a non-positive $limit is rejected by Spring / MongoDB anyway,
        // but only later and with a less helpful message.
        if (page < 0) {
            throw new IllegalArgumentException("Page index must not be negative, got " + page);
        }
        if (size <= 0) {
            throw new IllegalArgumentException("Page size must be greater than zero, got " + size);
        }
        Aggregation aggregation = createAggregation(domainToOperations(specification));
        // Widen before multiplying so large page numbers cannot overflow int.
        aggregation.getPipeline().add(Aggregation.skip((long) page * size));
        aggregation.getPipeline().add(Aggregation.limit(size));
        return aggregation;
    }

    /**
     * Runs the aggregation and returns all mapped results.
     *
     * @param aggregation the aggregation to run
     * @param collectionName the MongoDB collection name
     * @param outputClass the type each result document is mapped to
     * @return the mapped results
     */
    public <T> List<T> aggregate(Aggregation aggregation, String collectionName, Class<T> outputClass) {
        AggregationResults<T> results = mongoTemplate.aggregate(aggregation, collectionName, outputClass);
        return results.getMappedResults();
    }

    /**
     * Runs the aggregation and returns the results as a lazily consumed stream.
     * The stream is backed by a database cursor, so callers should close it
     * (for example with try-with-resources).
     *
     * @param aggregation the aggregation to run
     * @param collectionName the MongoDB collection name
     * @param outputClass the type each result document is mapped to
     * @return a stream of mapped results
     */
    public <T> Stream<T> stream(Aggregation aggregation, String collectionName, Class<T> outputClass) {
        return mongoTemplate.aggregateStream(aggregation, collectionName, outputClass);
    }

    /**
     * Counts the documents the given aggregation would produce without pagination.
     * Pagination and ordering stages ({@code $skip}, {@code $limit}, {@code $sort}) are removed.
     * All other stages are kept so that a {@code $match} depending on earlier stages (lookups,
     * added fields, unwinds) still works and cardinality-changing stages are honoured.
     *
     * @param aggregation the Aggregation we want the count for
     * @param collectionName the MongoDB collection name
     * @return the count as long, 0 when nothing matches
     */
    public long getAggregationCount(Aggregation aggregation, String collectionName) {
        // toCollection(ArrayList::new) guarantees a mutable list; Collectors.toList() does not.
        List<AggregationOperation> stages = aggregation.getPipeline().getOperations().stream()
                .filter(operation -> !isIrrelevantForCount(operation))
                .collect(Collectors.toCollection(ArrayList::new));
        stages.add(Aggregation.count().as("total"));
        // Keep the original options (e.g. collation), which can change what $match selects.
        Aggregation countAggregation = Aggregation.newAggregation(stages).withOptions(aggregation.getOptions());
        AggregationResults<Document> countResults =
                mongoTemplate.aggregate(countAggregation, collectionName, Document.class);
        // $count emits no document at all when the input is empty.
        Document resultDoc = countResults.getUniqueMappedResult();
        return resultDoc != null ? ((Number) resultDoc.get("total")).longValue() : 0L;
    }

    /**
     * Tells whether a stage only affects ordering or pagination and can therefore be dropped
     * when computing the total count.
     */
    private static boolean isIrrelevantForCount(AggregationOperation operation) {
        return operation instanceof SkipOperation
                || operation instanceof LimitOperation
                || operation instanceof SortOperation;
    }
}