package com.bcxin.flink.cdc.kafka.source.task.cdcs.dynamic;

import com.bcxin.flink.cdc.kafka.source.task.FlinkConstants;
import com.bcxin.event.core.JsonProvider;
import com.bcxin.event.core.JsonProviderImpl;
import com.bcxin.event.core.exceptions.BadEventException;
import com.mongodb.MongoException;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.InsertOneModel;
import com.mongodb.client.model.ReplaceOneModel;
import com.mongodb.client.model.ReplaceOptions;
import com.mongodb.client.model.WriteModel;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.functions.util.ListCollector;
import org.apache.flink.api.common.operators.MailboxExecutor;
import org.apache.flink.api.common.serialization.SerializationSchema;
import org.apache.flink.api.connector.sink2.Sink;
import org.apache.flink.api.connector.sink2.SinkWriter;
import org.apache.flink.connector.mongodb.common.config.MongoConnectionOptions;
import org.apache.flink.connector.mongodb.sink.config.MongoWriteOptions;
import org.apache.flink.connector.mongodb.sink.writer.MongoWriter;
import org.apache.flink.connector.mongodb.sink.writer.context.DefaultMongoSinkContext;
import org.apache.flink.connector.mongodb.sink.writer.context.MongoSinkContext;
import org.apache.flink.connector.mongodb.sink.writer.serializer.MongoSerializationSchema;
import org.apache.flink.metrics.Counter;
import org.apache.flink.metrics.groups.SinkWriterMetricGroup;
import org.apache.flink.util.Collector;
import org.apache.flink.util.FlinkRuntimeException;
import org.bson.*;
import org.bson.conversions.Bson;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;

import static org.apache.flink.util.Preconditions.checkNotNull;

/**
 * Flink {@link SinkWriter} adapted from the MongoDB connector's {@link MongoWriter}.
 *
 * <p>The upstream writer targets one fixed collection. This writer instead routes every record to
 * a database/collection taken from the record itself.
 *
 * <p>Each element is serialized with the configured {@link MongoSerializationSchema}. The schema
 * is expected to produce an {@link InsertOneModel} wrapping a change-event document with the
 * fields {@code source.db}, {@code source.table}, {@code before}, {@code after}, {@code op} and
 * {@code ts_ms}.
 *
 * <p>Buffered events are written in three situations:
 * <ul>
 *   <li>when the configured batch size or batch interval is reached;
 *   <li>on {@link #flush(boolean)}, if flush-on-checkpoint is enabled;
 *   <li>on {@link #flush(boolean)} at end of input.
 * </ul>
 *
 * <p>Each event's row image is upserted, keyed by the row's primary id, into collection
 * {@code lower(source.table)} of database {@code lower(source.db)}.
 *
 * @param <IN> type of the incoming stream elements
 */
public class DynamicDbCollectionMongoWriter<IN> implements SinkWriter<IN> {

    private static final Logger LOG = LoggerFactory.getLogger(DynamicDbCollectionMongoWriter.class);

    private final MongoConnectionOptions connectionOptions;
    private final MongoWriteOptions writeOptions;
    private final MongoSerializationSchema<IN> serializationSchema;
    private final MongoSinkContext sinkContext;
    private final MailboxExecutor mailboxExecutor;
    private final boolean flushOnCheckpoint;
    /** Serialized events buffered until the next bulk write; filled through {@link #collector}. */
    private final List<WriteModel<BsonDocument>> bulkRequests = new ArrayList<>();
    private final Collector<WriteModel<BsonDocument>> collector;
    private final Counter numRecordsOut;
    private final MongoClient mongoClient;

    /** Set while {@link #flush(boolean)} runs; {@link #write} yields until it is cleared. */
    private boolean checkpointInProgress = false;
    private volatile long lastSendTime = 0L;
    private volatile long ackTime = Long.MAX_VALUE;
    private final JsonProvider jsonProvider;

    /**
     * Opens the serialization schema, registers the send-time gauge and creates the Mongo client.
     *
     * @param connectionOptions connection settings; its URI is used to create the client
     * @param writeOptions batch size, batch interval and retry settings
     * @param flushOnCheckpoint whether {@link #flush(boolean)} writes buffered events at
     *     checkpoints (and not only at end of input)
     * @param initContext sink init context supplying the mailbox executor and metric group
     * @param serializationSchema converts elements into write models
     * @throws FlinkRuntimeException if the serialization schema fails to open
     */
    public DynamicDbCollectionMongoWriter(
            MongoConnectionOptions connectionOptions,
            MongoWriteOptions writeOptions,
            boolean flushOnCheckpoint,
            Sink.InitContext initContext,
            MongoSerializationSchema<IN> serializationSchema) {
        this.connectionOptions = checkNotNull(connectionOptions);
        this.writeOptions = checkNotNull(writeOptions);
        this.serializationSchema = checkNotNull(serializationSchema);
        this.flushOnCheckpoint = flushOnCheckpoint;

        checkNotNull(initContext);
        this.mailboxExecutor = checkNotNull(initContext.getMailboxExecutor());

        SinkWriterMetricGroup metricGroup = checkNotNull(initContext.metricGroup());
        // Duration of the most recent completed bulk write (ack time minus send time).
        metricGroup.setCurrentSendTimeGauge(() -> ackTime - lastSendTime);

        this.numRecordsOut = metricGroup.getNumRecordsSendCounter();
        this.collector = new ListCollector<>(this.bulkRequests);

        this.sinkContext = new DefaultMongoSinkContext(initContext, writeOptions);
        try {
            SerializationSchema.InitializationContext initializationContext =
                    initContext.asSerializationSchemaInitializationContext();
            serializationSchema.open(initializationContext, sinkContext, writeOptions);
        } catch (Exception e) {
            throw new FlinkRuntimeException("Failed to open the MongoEmitter", e);
        }

        // Initialize the mongo client.
        this.mongoClient = MongoClients.create(connectionOptions.getUri());
        this.jsonProvider = new JsonProviderImpl();
    }

    /**
     * Serializes {@code element} and buffers it.
     *
     * <p>A bulk write is triggered once the batch-size or batch-interval limit is reached. The
     * call yields to the mailbox while a flush is in progress.
     */
    @Override
    public void write(IN element, Context context) throws IOException, InterruptedException {
        // do not allow new bulk writes until all actions are flushed
        while (checkpointInProgress) {
            mailboxExecutor.yield();
        }
        WriteModel<BsonDocument> writeModel = serializationSchema.serialize(element, sinkContext);
        numRecordsOut.inc();
        collector.collect(writeModel);
        if (isOverMaxBatchSizeLimit() || isOverMaxBatchIntervalLimit()) {
            doBulkWrite();
        }
    }

    /**
     * Writes all buffered events when flush-on-checkpoint is enabled or {@code endOfInput} is
     * true; otherwise leaves them buffered.
     */
    @Override
    public void flush(boolean endOfInput) throws IOException {
        checkpointInProgress = true;
        while (!bulkRequests.isEmpty() && (flushOnCheckpoint || endOfInput)) {
            doBulkWrite();
        }
        checkpointInProgress = false;
    }

    /** Closes the Mongo client. Events still buffered at this point are not written. */
    @Override
    public void close() {
        mongoClient.close();
    }

    /**
     * Writes every buffered event, retrying on {@link MongoException}.
     *
     * <p>Retry details:
     * <ul>
     *   <li>at most {@code maxRetries} retries;
     *   <li>before retry {@code i + 1} the thread sleeps {@code retryIntervalMs * (i + 1)} ms;
     *   <li>the buffer is cleared only after a fully successful write.
     * </ul>
     *
     * <p>Every event becomes an upsert by {@code _id}, so re-running a partially applied batch
     * on retry converges to the same stored rows.
     *
     * @throws IOException if all attempts fail, or if the thread is interrupted while backing off
     */
    @VisibleForTesting
    void doBulkWrite() throws IOException {
        if (bulkRequests.isEmpty()) {
            // no records to write
            return;
        }

        int maxRetries = writeOptions.getMaxRetries();
        long retryIntervalMs = writeOptions.getRetryIntervalMs();
        for (int i = 0; i <= maxRetries; i++) {
            try {
                lastSendTime = System.currentTimeMillis();
                executeBulkRequests(bulkRequests);

                ackTime = System.currentTimeMillis();
                bulkRequests.clear();
                break;
            } catch (MongoException e) {
                LOG.debug("Bulk Write to MongoDB failed, retry times = {}", i, e);
                if (i >= maxRetries) {
                    LOG.error("Bulk Write to MongoDB failed", e);
                    throw new IOException(e);
                }
                try {
                    Thread.sleep(retryIntervalMs * (i + 1));
                } catch (InterruptedException ex) {
                    Thread.currentThread().interrupt();
                    throw new IOException(
                            "Unable to flush; interrupted while doing another attempt", e);
                }
            }
        }
    }

    /** True when the buffer reached the configured batch size ({@code -1} disables the limit). */
    private boolean isOverMaxBatchSizeLimit() {
        int bulkActions = writeOptions.getBatchSize();
        return bulkActions != -1 && bulkRequests.size() >= bulkActions;
    }

    /**
     * True when at least the configured batch interval elapsed since the last send ({@code -1}
     * disables the limit).
     *
     * <p>{@code lastSendTime} starts at 0, so with the interval enabled the first write sends
     * immediately.
     */
    private boolean isOverMaxBatchIntervalLimit() {
        long bulkFlushInterval = writeOptions.getBatchIntervalMs();
        long lastSentInterval = System.currentTimeMillis() - lastSendTime;
        return bulkFlushInterval != -1 && lastSentInterval >= bulkFlushInterval;
    }

    /**
     * Converts buffered change events into per-collection upserts and writes each collection's
     * upserts with one ordered {@code bulkWrite}.
     *
     * <p>Events are grouped by lower-cased {@code source.db} and {@code source.table}, and
     * arrival order is preserved within each collection. Events without a usable row image are
     * logged and skipped. A collection whose events were all skipped is not written at all,
     * because the driver rejects an empty {@code bulkWrite}.
     *
     * @param requests buffered write models, each expected to be an {@link InsertOneModel}
     * @throws BadEventException if an event lacks {@code source.db}, {@code source.table} or a
     *     primary id field
     * @throws MongoException if a bulk write fails; {@link #doBulkWrite()} retries on this
     */
    private void executeBulkRequests(List<WriteModel<BsonDocument>> requests) {
        // Maps db -> collection -> upserts. LinkedHashMap keeps the write order deterministic.
        Map<String, Map<String, List<WriteModel<BsonDocument>>>> groups = new LinkedHashMap<>();
        for (WriteModel<BsonDocument> writeModel : requests) {
            BsonDocument document = extractChangeEvent(writeModel);
            BsonDocument source = requireSource(document);
            // Locale.ROOT: database/collection names must not depend on the JVM default locale.
            String db = requireSourceString(source, "db", document).toLowerCase(Locale.ROOT);
            String table = requireSourceString(source, "table", document).toLowerCase(Locale.ROOT);

            WriteModel<BsonDocument> upsert = toUpsertModel(document, source);
            if (upsert == null) {
                LOG.error("无效数据; 找不到节点数据:{}", document);
                continue;
            }
            groups.computeIfAbsent(db, dbKey -> new LinkedHashMap<>())
                    .computeIfAbsent(table, tableKey -> new ArrayList<>())
                    .add(upsert);
        }

        for (Map.Entry<String, Map<String, List<WriteModel<BsonDocument>>>> dbEntry :
                groups.entrySet()) {
            String db = dbEntry.getKey();
            for (Map.Entry<String, List<WriteModel<BsonDocument>>> tableEntry :
                    dbEntry.getValue().entrySet()) {
                String table = tableEntry.getKey();
                List<WriteModel<BsonDocument>> models = tableEntry.getValue();

                MongoCollection<BsonDocument> mongoCollection =
                        mongoClient.getDatabase(db).getCollection(table, BsonDocument.class);
                mongoCollection.bulkWrite(models);

                LOG.info("数据同步完毕:{}.{};size={}", db, table, models.size());
            }
        }
    }

    /**
     * Returns the change-event document carried by a serialized write model.
     *
     * @throws IllegalStateException if the serialization schema produced something other than an
     *     {@link InsertOneModel}
     */
    private static BsonDocument extractChangeEvent(WriteModel<BsonDocument> writeModel) {
        if (!(writeModel instanceof InsertOneModel)) {
            throw new IllegalStateException(
                    "Expected an InsertOneModel from the serialization schema but got: "
                            + (writeModel == null ? "null" : writeModel.getClass().getName()));
        }
        return ((InsertOneModel<BsonDocument>) writeModel).getDocument();
    }

    /**
     * Returns the event's {@code source} sub-document.
     *
     * @throws BadEventException if it is missing or not a document
     */
    private BsonDocument requireSource(BsonDocument document) {
        BsonValue source = document.get("source");
        if (source == null || !source.isDocument()) {
            throw new BadEventException(
                    String.format("无效数据; 找不到节点source=%s", this.jsonProvider.getJson(document)));
        }
        return source.asDocument();
    }

    /**
     * Returns the string field {@code source.<field>}.
     *
     * @throws BadEventException if it is missing or not a string
     */
    private String requireSourceString(BsonDocument source, String field, BsonDocument document) {
        BsonValue value = source.get(field);
        if (value == null || !value.isString()) {
            throw new BadEventException(
                    String.format(
                            "无效数据; 找不到节点source.%s=%s",
                            field, this.jsonProvider.getJson(document)));
        }
        return value.asString().getValue();
    }

    /**
     * Builds the upsert for one change event.
     *
     * <p>The row image is {@code after}, or {@code before} when {@code after} is missing or null.
     * It is modified in place:
     * <ul>
     *   <li>numeric date-time fields are converted;
     *   <li>{@code _id} is set to the row's primary id;
     *   <li>{@code __meta.*} fields record the event metadata.
     * </ul>
     *
     * @return the upsert, or {@code null} when the event has no document-valued row image
     * @throws BadEventException if the row image has no {@code id}/{@code pk_id}/{@code pkId}
     *     field (case-insensitive)
     */
    private WriteModel<BsonDocument> toUpsertModel(BsonDocument document, BsonDocument source) {
        BsonValue beforeDataNode = document.get("before");
        BsonValue afterDataNode = document.get("after");

        // Prefer the "after" image; fall back to "before" when "after" is missing or null.
        BsonValue dataNode = isMissing(afterDataNode) ? beforeDataNode : afterDataNode;
        if (dataNode == null || !dataNode.isDocument()) {
            // Covers a missing image and before/after both being BSON null. Previously the
            // second case crashed in asDocument().
            return null;
        }
        BsonDocument contentDocument = dataNode.asDocument();

        Optional<String> idField =
                contentDocument.keySet().stream()
                        .filter(ix -> ix.equalsIgnoreCase("id")
                                || ix.equalsIgnoreCase("pk_id")
                                || ix.equalsIgnoreCase("pkId"))
                        .findFirst();
        if (!idField.isPresent()) {
            throw new BadEventException(
                    String.format("无效数据; 找不到节点主Id=%s", this.jsonProvider.getJson(document)));
        }
        BsonValue keyValue = contentDocument.get(idField.get());

        convertDateTimeFields(contentDocument);

        contentDocument.put("_id", keyValue);
        // NOTE(review): these keys contain '.', so MongoDB stores them as literal top-level
        // field names, not as a nested "__meta" sub-document. Dot-notation queries such as
        // {"__meta.op": ...} address a nested path and will not match them. A nested document
        // was presumably intended. Kept unchanged for compatibility with already-synced data;
        // confirm with consumers before changing.
        contentDocument.put("__meta.source", source);
        // BsonDocument.put rejects Java null, so absent values are stored as BSON null.
        contentDocument.put("__meta.op", orBsonNull(document.get("op")));
        contentDocument.put("__meta.ts_ms", orBsonNull(document.get("ts_ms")));
        // Reference comparison: record "before" only when the stored image is "after".
        if (beforeDataNode != dataNode) {
            contentDocument.put("__meta.before", orBsonNull(beforeDataNode));
        }
        contentDocument.put("__meta.sync_time", new BsonString(new Date().toString()));

        Bson bsonIdFilter = Filters.eq("_id", keyValue);
        // upsert(true): without it, a replace that matches no existing _id is a silent no-op,
        // so rows not yet present in MongoDB would never be written.
        return new ReplaceOneModel<>(bsonIdFilter, contentDocument, new ReplaceOptions().upsert(true));
    }

    /**
     * Rewrites numeric date-time fields as strings.
     *
     * <p>Applies to fields selected by {@link FlinkConstants#isDateTimeField} that hold a
     * {@link BsonNumber}. Each such field is replaced with the string form of
     * {@link FlinkConstants#formatValue}. When that result is already a {@link BsonValue}, the
     * field is left unchanged.
     */
    private static void convertDateTimeFields(BsonDocument contentDocument) {
        // Collect the keys up front so the document is not modified while its key set is streamed.
        List<String> dateTimeKeys =
                contentDocument.keySet().stream()
                        .filter(ix -> FlinkConstants.isDateTimeField(ix))
                        .collect(Collectors.toList());

        for (String key : dateTimeKeys) {
            BsonValue rawValue = contentDocument.get(key);
            if (rawValue instanceof BsonNumber) {
                Object value = FlinkConstants.formatValue(key, rawValue.asNumber());
                if (!(value instanceof BsonValue)) {
                    contentDocument.put(key, new BsonString(String.valueOf(value)));
                }
            }
        }
    }

    /** True when {@code value} is absent or BSON null. */
    private static boolean isMissing(BsonValue value) {
        return value == null || value instanceof BsonNull;
    }

    /** Returns {@code value}, or {@link BsonNull#VALUE} when it is absent. */
    private static BsonValue orBsonNull(BsonValue value) {
        return value == null ? BsonNull.VALUE : value;
    }
}