package com.bcxin.tenant.open.infrastructure.tx.components;

import com.bcxin.tenant.open.infrastructures.UnitWork;
import com.bcxin.tenant.open.infrastructures.entities.Aggregate;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;

import java.util.Collection;
import java.util.Map;
import java.util.UUID;

/**
 * {@link UnitWork} implementation that drives a Spring {@link PlatformTransactionManager}
 * programmatically.
 *
 * <p>At most one transaction is tracked per thread. The first {@link #beginTransaction()} call on a
 * thread starts the transaction and registers it under the returned id; further calls made while
 * that transaction is still registered only hand out a fresh id that owns nothing, so the matching
 * {@link #commit(String)} / {@link #rollback(String)} calls are no-ops and the outermost caller
 * decides the outcome.
 */
public class UnitWorkTxImpl implements UnitWork {
    /**
     * Transaction registered for the current thread, keyed by the id that owns it.
     *
     * <p>Deliberately a plain {@link ThreadLocal}: with an inheritable one, a thread spawned while a
     * transaction is open would copy the parent's entry, never start a transaction of its own and
     * never clear the copied entry.
     */
    private static final ThreadLocal<Map<String, TransactionStatus>> ACTIVE_TRANSACTION = new ThreadLocal<>();

    private final PlatformTransactionManager transactionManager;

    /**
     * @param transactionManager manager used to start, commit and roll back transactions
     */
    public UnitWorkTxImpl(PlatformTransactionManager transactionManager) {
        this.transactionManager = transactionManager;
    }

    /**
     * Starts a transaction for the current thread unless one is already registered.
     *
     * @return a new id; it owns the transaction only when this call actually started it
     */
    @Override
    public String beginTransaction() {
        String txId = String.format("%s-%s", UUID.randomUUID(), Thread.currentThread().getId());
        if (ACTIVE_TRANSACTION.get() == null) {
            // The default definition is used on purpose (its propagation is PROPAGATION_REQUIRED).
            TransactionStatus transactionStatus =
                    this.transactionManager.getTransaction(TransactionDefinition.withDefaults());

            ACTIVE_TRANSACTION.set(Map.of(txId, transactionStatus));
        }

        return txId;
    }

    /**
     * Commits the transaction owned by {@code tid}; does nothing when {@code tid} owns none.
     *
     * @param tid id previously returned by {@link #beginTransaction()}
     */
    @Override
    public void commit(String tid) {
        TransactionStatus transactionStatus = findOwnedTransaction(tid);
        if (transactionStatus == null) {
            return;
        }

        try {
            this.transactionManager.commit(transactionStatus);
        } finally {
            // Always unregister, even when the commit throws: the status must not be reused and
            // the next beginTransaction() on this thread has to start a fresh transaction.
            ACTIVE_TRANSACTION.remove();
        }
    }

    /** No-op in this implementation. */
    @Override
    public void detachAll() {
    }

    /** No-op in this implementation. */
    @Override
    public void detach(Aggregate aggregate) {

    }

    /** No-op in this implementation. */
    @Override
    public <T extends Aggregate> void detachAll(Collection<T> aggregates) {

    }

    /**
     * Runs {@code runnable} between {@link #beginTransaction()} and {@link #commit(String)},
     * rolling back when the runnable fails.
     *
     * <p>The failure thrown by the runnable is rethrown after the rollback. If the rollback itself
     * fails, the rollback failure is thrown with the runnable's failure attached as suppressed.
     *
     * @param runnable work to execute
     */
    @Override
    public void executeTran(Runnable runnable) {
        String trId = this.beginTransaction();
        try {
            runnable.run();
        } catch (RuntimeException | Error ex) {
            // Errors are handled too, otherwise the transaction would stay open on this thread.
            rollbackAfterFailure(trId, ex);
            throw ex;
        }

        // Kept outside the try block: a rollback attempted after a failed commit would hit an
        // already completed transaction and hide the real commit failure.
        this.commit(trId);
    }

    /**
     * Rolls back the transaction owned by {@code tid}; does nothing when {@code tid} owns none.
     *
     * @param tid id previously returned by {@link #beginTransaction()}
     */
    @Override
    public void rollback(String tid) {
        TransactionStatus transactionStatus = findOwnedTransaction(tid);
        if (transactionStatus == null) {
            return;
        }

        try {
            this.transactionManager.rollback(transactionStatus);
        } finally {
            // Same reasoning as in commit(): never leave a finished status registered.
            ACTIVE_TRANSACTION.remove();
        }
    }

    /**
     * Delegates to {@link #executeTran(Runnable)}.
     *
     * <p>NOTE(review): because of that delegation, no independent transaction is started when one
     * is already registered for the current thread - confirm whether that is intended.
     *
     * @param runnable work to execute
     */
    @Override
    public void executeNewTran(Runnable runnable) {
        this.executeTran(runnable);
    }

    /** Returns the status registered under {@code tid} on this thread, or {@code null} if none. */
    private static TransactionStatus findOwnedTransaction(String tid) {
        Map<String, TransactionStatus> registered = ACTIVE_TRANSACTION.get();
        if (registered == null || tid == null) {
            return null;
        }
        return registered.get(tid);
    }

    /** Rolls back {@code tid} and keeps {@code failure} visible if the rollback itself fails. */
    private void rollbackAfterFailure(String tid, Throwable failure) {
        try {
            this.rollback(tid);
        } catch (RuntimeException | Error rollbackFailure) {
            if (rollbackFailure != failure) {
                rollbackFailure.addSuppressed(failure);
            }
            throw rollbackFailure;
        }
    }
}
