/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.mcpserver;

import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.action.mcpserver.McpToolsHelper;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLIndex;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.mcpserver.action.MLMcpToolsRegisterOnNodesAction;
import org.opensearch.ml.common.transport.mcpserver.requests.McpToolBaseInput;
import org.opensearch.ml.common.transport.mcpserver.requests.register.MLMcpToolsRegisterNodesRequest;
import org.opensearch.ml.common.transport.mcpserver.requests.register.McpToolRegisterInput;
import org.opensearch.ml.common.transport.mcpserver.responses.register.MLMcpToolsRegisterNodesResponse;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action behind {@code cluster:admin/opensearch/ml/mcp_tools/register}.
 *
 * <p>Registration runs in three asynchronous stages:
 * <ol>
 *   <li>initialise the MCP tools index and reject the request if searching it for the requested
 *       tool names returns any existing tool;</li>
 *   <li>persist every requested tool into the index using {@code CREATE} operations;</li>
 *   <li>register the successfully persisted tools through {@link MLMcpToolsRegisterOnNodesAction}.</li>
 * </ol>
 * Partial failures from stages 2 and 3 are accumulated into a newline-separated message and
 * reported to the caller as an {@link OpenSearchException}.
 */
public class TransportMcpToolsRegisterAction
extends HandledTransportAction<ActionRequest, MLMcpToolsRegisterNodesResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportMcpToolsRegisterAction.class);
    TransportService transportService;
    ClusterService clusterService;
    ThreadPool threadPool;
    Client client;
    NamedXContentRegistry xContentRegistry;
    DiscoveryNodeHelper nodeFilter;
    private MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private final MLIndicesHandler mlIndicesHandler;
    private final McpToolsHelper mcpToolsHelper;

    @Inject
    public TransportMcpToolsRegisterAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, ThreadPool threadPool, Client client, NamedXContentRegistry xContentRegistry, DiscoveryNodeHelper nodeFilter, MLIndicesHandler mlIndicesHandler, McpToolsHelper mcpToolsHelper, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        super("cluster:admin/opensearch/ml/mcp_tools/register", transportService, actionFilters, MLMcpToolsRegisterNodesRequest::new);
        this.transportService = transportService;
        this.clusterService = clusterService;
        this.threadPool = threadPool;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.nodeFilter = nodeFilter;
        this.mlIndicesHandler = mlIndicesHandler;
        this.mcpToolsHelper = mcpToolsHelper;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
    }

    /**
     * Validates that the MCP server feature is enabled, initialises the MCP tools index, and fails
     * if any requested tool name is already present; otherwise continues with {@link #indexMcpTools}.
     *
     * @param task     the transport task (unused)
     * @param request  expected to be an {@link MLMcpToolsRegisterNodesRequest}
     * @param listener notified with the node registration response or the first failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLMcpToolsRegisterNodesResponse> listener) {
        if (!this.mlFeatureEnabledSetting.isMcpServerEnabled()) {
            listener.onFailure(new OpenSearchException(MLCommonsSettings.ML_COMMONS_MCP_SERVER_DISABLED_MESSAGE));
            return;
        }
        MLMcpToolsRegisterNodesRequest registerNodesRequest = (MLMcpToolsRegisterNodesRequest) request;
        // Run the index/search calls with a stashed thread context; the caller's context is restored
        // right before the listener is notified.
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLMcpToolsRegisterNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<Boolean> initIndexListener = ActionListener.wrap(created -> {
                ActionListener<List<McpToolRegisterInput>> searchResultListener = ActionListener.wrap(searchResult -> {
                    if (!searchResult.isEmpty()) {
                        Set<String> registerToolNames = registerNodesRequest
                            .getMcpTools()
                            .stream()
                            .map(McpToolBaseInput::getName)
                            .collect(Collectors.toSet());
                        List<String> existingTools = searchResult
                            .stream()
                            .map(McpToolBaseInput::getName)
                            .filter(registerToolNames::contains)
                            .toList();
                        String exceptionMessage = String.format(Locale.ROOT, "Unable to register tools: %s as they already exist", existingTools);
                        log.warn(exceptionMessage);
                        restoreListener.onFailure(new OpenSearchException(exceptionMessage));
                    } else {
                        this.indexMcpTools(registerNodesRequest, restoreListener);
                    }
                }, e -> {
                    log.error("Failed to search mcp tools index", e);
                    restoreListener.onFailure(e);
                });
                List<String> requestedToolNames = registerNodesRequest.getMcpTools().stream().map(McpToolBaseInput::getName).toList();
                this.mcpToolsHelper.searchToolsWithVersion(requestedToolNames, searchResultListener);
            }, e -> {
                log.error("Failed to create .plugins-ml-mcp-tools index", e);
                restoreListener.onFailure(e);
            });
            this.mlIndicesHandler.initMLMcpToolsIndex(initIndexListener);
        } catch (Exception e) {
            log.error("Failed to register mcp tools", e);
            listener.onFailure(e);
        }
    }

    /**
     * Persists every tool of {@code registerNodesRequest} into the MCP tools index in one bulk
     * request, then hands the outcome to {@link #handleBulkResponse}.
     */
    private void indexMcpTools(MLMcpToolsRegisterNodesRequest registerNodesRequest, ActionListener<MLMcpToolsRegisterNodesResponse> listener) {
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLMcpToolsRegisterNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<BulkResponse> indexResultListener = ActionListener.wrap(
                bulkResponse -> this.handleBulkResponse(registerNodesRequest, bulkResponse, restoreListener),
                e -> {
                    // Pass the throwable as the trailing argument so the stack trace is logged too.
                    log.error("Failed to persist mcp tools into system index because exception: {}", e.getMessage(), e);
                    restoreListener.onFailure(e);
                }
            );
            this.client.bulk(this.buildBulkRequest(registerNodesRequest), indexResultListener);
        } catch (Exception e) {
            log.error("Failed to register mcp tools", e);
            listener.onFailure(e);
        }
    }

    /**
     * Builds one {@code CREATE} index request per tool (document id = tool name), in the same order
     * as {@code registerNodesRequest.getMcpTools()}, with an immediate refresh.
     */
    private BulkRequest buildBulkRequest(MLMcpToolsRegisterNodesRequest registerNodesRequest) {
        BulkRequest bulkRequest = new BulkRequest();
        bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        for (McpToolRegisterInput mcpTool : registerNodesRequest.getMcpTools()) {
            Map<String, Object> source = new HashMap<>();
            source.put("name", mcpTool.getName());
            source.put("type", mcpTool.getType());
            source.put("parameters", mcpTool.getParameters());
            source.put("attributes", mcpTool.getAttributes());
            source.put("description", mcpTool.getDescription());
            source.put("create_time", Instant.now().toEpochMilli());
            IndexRequest indexRequest = new IndexRequest(MLIndex.MCP_TOOLS.getIndexName());
            indexRequest.opType(DocWriteRequest.OpType.CREATE);
            indexRequest.id(mcpTool.getName());
            indexRequest.source(source);
            bulkRequest.add(indexRequest);
        }
        return bulkRequest;
    }

    /**
     * Continues registration after the bulk write. If nothing failed, all tools go to the nodes.
     * Otherwise failed tools are removed from the request, their errors are collected, and the
     * remaining tools (if any) are still registered on the nodes.
     */
    private void handleBulkResponse(
        MLMcpToolsRegisterNodesRequest registerNodesRequest,
        BulkResponse bulkResponse,
        ActionListener<MLMcpToolsRegisterNodesResponse> listener
    ) {
        if (!bulkResponse.hasFailures()) {
            Set<String> toolNames = registerNodesRequest
                .getMcpTools()
                .stream()
                .map(McpToolBaseInput::getName)
                .collect(Collectors.toUnmodifiableSet());
            this.registerMcpToolsOnNodes(new StringBuilder(), this.updateVersion(registerNodesRequest, bulkResponse), toolNames, listener);
            return;
        }

        Set<String> indexSucceedTools = new LinkedHashSet<>();
        Map<String, String> indexFailedTools = new LinkedHashMap<>();
        Set<Integer> failedPositions = new HashSet<>();
        for (BulkItemResponse item : bulkResponse.getItems()) {
            if (item.isFailed()) {
                indexFailedTools.put(item.getId(), item.getFailure().getMessage());
                failedPositions.add(item.getItemId());
            } else {
                indexSucceedTools.add(item.getId());
            }
        }

        // Remove failed tools by bulk position, not by name: if a name appears twice in one request,
        // the first CREATE can succeed while the second conflicts, and name-based removal would also
        // drop the persisted tool from the nodes request. Iterate backwards so indices stay valid.
        List<McpToolRegisterInput> tools = registerNodesRequest.getMcpTools();
        for (int i = tools.size() - 1; i >= 0; i--) {
            if (failedPositions.contains(i)) {
                tools.remove(i);
            }
        }

        StringBuilder errMsgBuilder = new StringBuilder();
        indexFailedTools.forEach((toolName, message) -> errMsgBuilder
            .append(String.format(Locale.ROOT, "Failed to persist mcp tool: %s into system index with error: %s", toolName, message))
            .append("\n"));
        log.error(errMsgBuilder.toString());
        if (!indexSucceedTools.isEmpty()) {
            this.registerMcpToolsOnNodes(errMsgBuilder, this.updateVersion(registerNodesRequest, bulkResponse), indexSucceedTools, listener);
        } else {
            // hasFailures() was true, so at least one line (ending in "\n") was appended above.
            listener.onFailure(new OpenSearchException(errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1).toString()));
        }
    }

    /**
     * Copies the document version of each successfully indexed tool onto the matching tool in the
     * request (matched by name; tools without a successful bulk item get {@code null}).
     *
     * @return the same, mutated request instance
     */
    private MLMcpToolsRegisterNodesRequest updateVersion(MLMcpToolsRegisterNodesRequest registerNodesRequest, BulkResponse bulkResponse) {
        Map<String, Long> version = Arrays
            .stream(bulkResponse.getItems())
            .filter(x -> !x.isFailed())
            .collect(Collectors.toMap(BulkItemResponse::getId, x -> x.getResponse().getVersion()));
        registerNodesRequest.getMcpTools().forEach(x -> x.setVersion(version.get(x.getName())));
        return registerNodesRequest;
    }

    /**
     * Executes {@link MLMcpToolsRegisterOnNodesAction} for the given request. The listener gets the
     * node response only when there were no node failures and {@code errMsgBuilder} is empty;
     * otherwise it gets an {@link OpenSearchException} built from the accumulated messages.
     *
     * @param errMsgBuilder     errors accumulated by earlier stages (each line ends with "\n"); appended to here
     * @param indexSucceedTools names of the tools that were persisted, used in error messages
     */
    private void registerMcpToolsOnNodes(StringBuilder errMsgBuilder, MLMcpToolsRegisterNodesRequest registerNodesRequest, Set<String> indexSucceedTools, ActionListener<MLMcpToolsRegisterNodesResponse> listener) {
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLMcpToolsRegisterNodesResponse> restoreListener = ActionListener.runBefore(listener, context::restore);
            ActionListener<MLMcpToolsRegisterNodesResponse> addToMemoryResultListener = ActionListener.wrap(r -> {
                if (r.failures() != null && !r.failures().isEmpty()) {
                    r.failures().forEach(x -> errMsgBuilder
                        .append(String.format(Locale.ROOT, "Tools: %s are persisted successfully but failed to register to mcp server memory with error: %s", indexSucceedTools, x.getRootCause().getMessage()))
                        .append("\n"));
                    errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1);
                    log.error(errMsgBuilder.toString());
                    restoreListener.onFailure(new OpenSearchException(errMsgBuilder.toString()));
                } else if (errMsgBuilder.isEmpty()) {
                    restoreListener.onResponse(r);
                } else {
                    restoreListener.onFailure(new OpenSearchException(errMsgBuilder.deleteCharAt(errMsgBuilder.length() - 1).toString()));
                }
            }, e -> {
                errMsgBuilder.append(String.format(Locale.ROOT, "Tools are persisted successfully but failed to register to mcp server memory with error: %s", e.getMessage()));
                log.error(errMsgBuilder.toString(), e);
                restoreListener.onFailure(new OpenSearchException(errMsgBuilder.toString()));
            });
            this.client.execute(MLMcpToolsRegisterOnNodesAction.INSTANCE, registerNodesRequest, addToMemoryResultListener);
        } catch (Exception e) {
            log.error("Failed to register mcp tools on nodes", e);
            listener.onFailure(e);
        }
    }
}

