/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shenyu.plugin.mcp.server.manager;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
import io.modelcontextprotocol.server.McpAsyncServer;
import io.modelcontextprotocol.server.McpServer;
import io.modelcontextprotocol.server.McpServerFeatures;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import org.apache.shenyu.plugin.mcp.server.callback.ShenyuToolCallback;
import org.apache.shenyu.plugin.mcp.server.definition.ShenyuToolDefinition;
import org.apache.shenyu.plugin.mcp.server.transport.ShenyuSseServerTransportProvider;
import org.apache.shenyu.plugin.mcp.server.transport.ShenyuStreamableHttpServerTransportProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.mcp.McpToolUtils;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.stereotype.Component;
import org.springframework.util.AntPathMatcher;
import org.springframework.web.reactive.function.server.HandlerFunction;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

@Component
public class ShenyuMcpServerManager {
    /** Logger shared by the manager and its nested composite transport provider. */
    private static final Logger LOG = LoggerFactory.getLogger(ShenyuMcpServerManager.class);
    /** Protocol key under which an SSE transport is stored in a composite provider. */
    private static final String SSE_PROTOCOL = "SSE";
    /** Protocol key under which a Streamable HTTP transport is stored in a composite provider. */
    private static final String STREAMABLE_HTTP_PROTOCOL = "Streamable HTTP";
    /** Matches request URIs against the Ant-style patterns stored as keys of {@link #routeMap}. */
    private final AntPathMatcher pathMatcher = new AntPathMatcher();
    /** JSON mapper handed to every transport provider built by this manager. */
    private final ObjectMapper objectMapper = new ObjectMapper();
    /** Shared MCP servers keyed by normalized server path. */
    private final Map<String, McpAsyncServer> sharedServerMap = new ConcurrentHashMap<>();
    /** Route handlers keyed by exact path or by Ant pattern ({@code path + "/**"}). */
    private final Map<String, HandlerFunction<?>> routeMap = new ConcurrentHashMap<>();
    /** Composite transport providers keyed by normalized server path. */
    private final Map<String, CompositeTransportProvider> compositeTransportMap = new ConcurrentHashMap<>();

    /**
     * Returns the SSE transport provider of the shared server that owns {@code uri},
     * building and registering it when none is stored yet.
     *
     * @param uri             request URI; reduced to the shared server path before lookup
     * @param messageEndPoint suffix appended to the server path to form the message endpoint
     * @return the stored or freshly created SSE transport provider
     */
    public ShenyuSseServerTransportProvider getOrCreateMcpServerTransport(String uri, String messageEndPoint) {
        final String serverPath = this.processPath(uri);
        // The factory is only invoked when no SSE transport is stored for this path yet.
        Supplier<ShenyuSseServerTransportProvider> sseFactory =
                () -> this.createSseTransport(serverPath, messageEndPoint);
        return this.getOrCreateTransport(serverPath, SSE_PROTOCOL, sseFactory);
    }

    /**
     * Returns the Streamable HTTP transport provider of the shared server that owns
     * {@code uri}, building and registering it when none is stored yet.
     *
     * @param uri request URI; the shared server is looked up by its normalized form while the
     *            transport endpoint and routes are registered under the URI as given
     * @return the stored or freshly created Streamable HTTP transport provider
     */
    public ShenyuStreamableHttpServerTransportProvider getOrCreateStreamableHttpTransport(String uri) {
        final String serverPath = this.processPath(uri);
        // The factory is only invoked when no Streamable HTTP transport is stored for this path yet.
        Supplier<ShenyuStreamableHttpServerTransportProvider> httpFactory =
                () -> this.createStreamableHttpTransport(serverPath, uri);
        return this.getOrCreateTransport(serverPath, STREAMABLE_HTTP_PROTOCOL, httpFactory);
    }

    /**
     * Looks up the transport registered for {@code protocol} on the composite provider of
     * {@code normalizedPath}; when none exists, builds one with {@code transportFactory} and
     * attaches it to the shared server for that path.
     *
     * <p>The lookup and the creation are separate steps, so two concurrent first calls for the
     * same path and protocol may each invoke the factory.
     *
     * @param normalizedPath   normalized server path used as the lookup key
     * @param protocol         {@code SSE_PROTOCOL} or {@code STREAMABLE_HTTP_PROTOCOL}; for any
     *                         other value nothing is looked up and the factory is always used
     * @param transportFactory creates the transport when none is registered yet
     * @param <T>              concrete transport provider type expected by the caller
     * @return the existing or newly created transport
     */
    @SuppressWarnings("unchecked") // callers pair each protocol constant with the matching factory type
    private <T> T getOrCreateTransport(String normalizedPath, String protocol, Supplier<T> transportFactory) {
        CompositeTransportProvider compositeTransport = this.getOrCreateCompositeTransport(normalizedPath);
        Object transport;
        switch (protocol) {
            case SSE_PROTOCOL:
                transport = compositeTransport.getSseTransport();
                break;
            case STREAMABLE_HTTP_PROTOCOL:
                transport = compositeTransport.getStreamableHttpTransport();
                break;
            default:
                transport = null;
                break;
        }
        if (Objects.isNull(transport)) {
            transport = transportFactory.get();
            this.addTransportToSharedServer(normalizedPath, protocol, transport);
        }
        return (T) transport;
    }

    /**
     * Converts a request URI into the key used for the shared-server maps by first
     * extracting its base path and then normalizing that base path.
     *
     * @param uri request URI
     * @return the normalized server path
     */
    private String processPath(String uri) {
        String basePath = this.extractBasePath(uri);
        return this.normalizeServerPath(basePath);
    }

    /**
     * Returns the composite transport provider stored for {@code normalizedPath}, creating an
     * empty one when the path has none yet.
     *
     * @param normalizedPath normalized server path
     * @return the composite transport provider for that path
     */
    private CompositeTransportProvider getOrCreateCompositeTransport(String normalizedPath) {
        return this.compositeTransportMap.computeIfAbsent(normalizedPath, missingPath -> {
            LOG.debug("Creating composite transport provider for path: {}", missingPath);
            CompositeTransportProvider fresh = new CompositeTransportProvider();
            return fresh;
        });
    }

    /**
     * Registers {@code transportProvider} under {@code protocol} on the composite provider of
     * {@code normalizedPath}, after making sure the shared server for that path exists.
     *
     * @param normalizedPath    normalized server path
     * @param protocol          protocol key the transport is stored under
     * @param transportProvider the transport to register
     */
    private void addTransportToSharedServer(String normalizedPath, String protocol, Object transportProvider) {
        // Called for its side effect only: creates the shared server on first use.
        this.getOrCreateSharedServer(normalizedPath);

        final CompositeTransportProvider composite = this.compositeTransportMap.get(normalizedPath);
        if (composite != null) {
            composite.addTransport(protocol, transportProvider);
        }
        LOG.info("Added {} transport to shared server at path: {}", protocol, normalizedPath);
    }

    /**
     * Returns the shared MCP server for {@code normalizedPath}, building it on first use.
     * A new server is backed by the path's composite transport provider, advertises tool and
     * logging capabilities, and starts with an empty tool list.
     *
     * @param normalizedPath normalized server path
     * @return the shared asynchronous MCP server for that path
     */
    private McpAsyncServer getOrCreateSharedServer(String normalizedPath) {
        return this.sharedServerMap.computeIfAbsent(normalizedPath, serverPath -> {
            LOG.info("Creating shared MCP server for path: {}", serverPath);
            McpServerTransportProvider transport = this.getOrCreateCompositeTransport(serverPath);
            McpSchema.ServerCapabilities capabilities = McpSchema.ServerCapabilities.builder()
                    .tools(Boolean.TRUE)
                    .logging()
                    .build();
            McpAsyncServer created = McpServer.async(transport)
                    .serverInfo("MCP Shenyu Server (Multi-Protocol)", "1.0.0")
                    .capabilities(capabilities)
                    .tools(Lists.newArrayList())
                    .build();
            LOG.info("Created shared MCP server for path: {} with multi-protocol support", serverPath);
            return created;
        });
    }

    /**
     * Builds an SSE transport provider for {@code normalizedPath} and registers its routes:
     * the server path for SSE connections and {@code normalizedPath + messageEndPoint} for
     * client messages.
     *
     * @param normalizedPath  normalized server path, used as the SSE endpoint
     * @param messageEndPoint suffix appended to the server path to form the message endpoint
     * @return the new SSE transport provider
     */
    private ShenyuSseServerTransportProvider createSseTransport(String normalizedPath, String messageEndPoint) {
        final String fullMessagePath = normalizedPath + messageEndPoint;
        final ShenyuSseServerTransportProvider sseProvider = ShenyuSseServerTransportProvider.builder()
                .objectMapper(this.objectMapper)
                .sseEndpoint(normalizedPath)
                .messageEndpoint(fullMessagePath)
                .build();

        HandlerFunction<?> connectionHandler = sseProvider::handleSseConnection;
        HandlerFunction<?> messageHandler = sseProvider::handleMessage;
        this.registerRoutes(normalizedPath, fullMessagePath, connectionHandler, messageHandler);

        LOG.debug("Created SSE transport for path: {}", normalizedPath);
        return sseProvider;
    }

    /**
     * Builds a Streamable HTTP transport provider whose endpoint is {@code originalUri} and
     * registers a single route for that URI (no secondary route).
     *
     * @param normalizedPath normalized server path; only used in the log message here
     * @param originalUri    URI as supplied by the caller, used as endpoint and route key
     * @return the new Streamable HTTP transport provider
     */
    private ShenyuStreamableHttpServerTransportProvider createStreamableHttpTransport(String normalizedPath, String originalUri) {
        final ShenyuStreamableHttpServerTransportProvider httpProvider =
                ShenyuStreamableHttpServerTransportProvider.builder()
                        .objectMapper(this.objectMapper)
                        .endpoint(originalUri)
                        .build();

        HandlerFunction<?> unifiedHandler = httpProvider::handleUnifiedEndpoint;
        this.registerRoutes(originalUri, null, unifiedHandler, null);

        LOG.debug("Created Streamable HTTP transport for original URI: {} (normalized: {})", originalUri, normalizedPath);
        return httpProvider;
    }

    /**
     * Stores route handlers in the route map. Each registered path is stored twice: once as
     * the exact path and once as the Ant pattern {@code path + "/**"}.
     *
     * @param primaryPath      path that is always registered
     * @param secondaryPath    optional second path; skipped when {@code null}
     * @param primaryHandler   handler for the primary path
     * @param secondaryHandler handler for the secondary path; skipped when {@code null}
     */
    private void registerRoutes(String primaryPath, String secondaryPath, HandlerFunction<?> primaryHandler, HandlerFunction<?> secondaryHandler) {
        final String wildcardSuffix = "/**";

        this.routeMap.put(primaryPath, primaryHandler);
        this.routeMap.put(primaryPath + wildcardSuffix, primaryHandler);

        // The secondary route is registered only when both its path and its handler are given.
        if (secondaryPath == null || secondaryHandler == null) {
            return;
        }
        this.routeMap.put(secondaryPath, secondaryHandler);
        this.routeMap.put(secondaryPath + wildcardSuffix, secondaryHandler);
    }

    /**
     * Reduces a request URI to the base path of its server: a trailing {@code "/message"} is
     * stripped, and if the remainder splits on {@code "/"} into more than two parts, only
     * {@code "/" +} the second part is kept.
     *
     * <p>Examples: {@code "/mcp/sse/message" -> "/mcp"}, {@code "/mcp" -> "/mcp"}.
     * The second split element is treated as the first path segment, which presumes the URI
     * starts with {@code "/"}.
     *
     * @param uri request URI; must not be {@code null}
     * @return the base path derived from {@code uri}
     */
    private String extractBasePath(String uri) {
        final String messageSuffix = "/message";
        String basePath = uri;
        if (basePath.endsWith(messageSuffix)) {
            basePath = basePath.substring(0, basePath.length() - messageSuffix.length());
        }
        String[] pathSegments = basePath.split("/");
        if (pathSegments.length > 2) {
            // pathSegments[0] is the empty string before the leading slash.
            basePath = "/" + pathSegments[1];
        }
        return basePath;
    }

    /**
     * Tells whether a shared MCP server is registered for the server path derived from
     * {@code uri}.
     *
     * @param uri request URI
     * @return {@code true} when a shared server exists for the derived path
     */
    public boolean hasMcpServer(String uri) {
        final String serverPath = this.processPath(uri);
        boolean registered = this.sharedServerMap.containsKey(serverPath);
        return registered;
    }

    /**
     * Tells whether {@code uri} can be served by a registered route: either it is itself a
     * key of the route map, or it matches one of the keys interpreted as an Ant pattern.
     *
     * @param uri request URI
     * @return {@code true} when an exact or pattern route matches
     */
    public boolean canRoute(String uri) {
        // Exact key lookup first; the pattern scan below only runs when it misses.
        if (this.routeMap.containsKey(uri)) {
            return true;
        }
        return this.routeMap.keySet().stream().anyMatch(candidate -> {
            boolean matched = this.pathMatcher.match(candidate, uri);
            if (matched) {
                LOG.debug("URI '{}' matches pattern '{}'", uri, candidate);
            }
            return matched;
        });
    }

    /**
     * Removes everything registered for the server that owns {@code uri}: its composite
     * transport provider (closed asynchronously), its shared server entry, and every route
     * whose key starts with the normalized path or with {@code uri}.
     *
     * <p>NOTE(review): routes are matched by plain string prefix, so routes of another server
     * whose path merely begins with the same characters would also be removed — confirm
     * whether that can occur with the paths in use.
     *
     * @param uri request URI identifying the server to remove
     */
    public void removeMcpServer(String uri) {
        final String serverPath = this.processPath(uri);
        LOG.info("Removing MCP server for URI: {} (normalized: {})", uri, serverPath);

        final CompositeTransportProvider detached = this.compositeTransportMap.remove(serverPath);
        if (detached != null) {
            // Fire-and-forget shutdown; the outcome is only logged.
            Mono<Void> shutdown = detached.closeGracefully();
            shutdown
                    .doOnSuccess(aVoid -> LOG.info("Successfully closed composite transport for path: {}", serverPath))
                    .doOnError(e -> LOG.error("Error closing composite transport for path: {}", serverPath, e))
                    .subscribe();
        }

        this.sharedServerMap.remove(serverPath);
        this.routeMap.entrySet().removeIf(entry -> {
            String routeKey = entry.getKey();
            return routeKey.startsWith(Objects.requireNonNull(serverPath)) || routeKey.startsWith(uri);
        });
        LOG.info("Removed MCP server for path: {}", serverPath);
    }

    /**
     * Adds (or replaces) a tool on the shared MCP server registered for {@code serverPath}.
     * Any tool already registered under {@code name} is removed first on a best-effort basis.
     * The call blocks until the server has accepted the tool; failures are logged, not thrown.
     * When no shared server exists for the path, a warning is logged and nothing is added.
     *
     * @param serverPath      server path; normalized before the shared server lookup
     * @param name            tool name
     * @param description     tool description
     * @param requestTemplate request configuration passed to the tool definition
     * @param inputSchema     input schema passed to the tool definition
     */
    public void addTool(String serverPath, String name, String description, String requestTemplate, String inputSchema) {
        String normalizedPath = this.normalizeServerPath(serverPath);
        try {
            this.removeTool(serverPath, name);
        } catch (Exception e) {
            // Best effort: a failed pre-removal must not prevent the add, but should not vanish silently.
            LOG.debug("Ignoring failure while removing existing tool '{}' before re-adding it for path: {}", name, normalizedPath, e);
        }

        ToolDefinition shenyuToolDefinition = ShenyuToolDefinition.builder()
                .name(name)
                .description(description)
                .requestConfig(requestTemplate)
                .inputSchema(inputSchema)
                .build();
        LOG.debug("Adding tool to shared server - name: {}, description: {}, path: {}", name, description, normalizedPath);
        ShenyuToolCallback shenyuToolCallback = new ShenyuToolCallback(shenyuToolDefinition);

        McpAsyncServer sharedServer = this.sharedServerMap.get(normalizedPath);
        if (Objects.isNull(sharedServer)) {
            LOG.warn("No shared server found for path: {}", normalizedPath);
            return;
        }
        try {
            for (McpServerFeatures.AsyncToolSpecification asyncToolSpecification
                    : McpToolUtils.toAsyncToolSpecifications(new ToolCallback[]{shenyuToolCallback})) {
                sharedServer.addTool(asyncToolSpecification).block();
            }
            Set<String> protocols = this.getSupportedProtocols(normalizedPath);
            LOG.info("Added tool '{}' to shared server for path: {} (available across protocols: {})", name, normalizedPath, protocols);
        } catch (Exception e) {
            // The trailing Throwable is logged with its stack trace.
            LOG.error("Failed to add tool '{}' to shared server for path: {}", name, normalizedPath, e);
        }
    }

    /**
     * Removes the tool called {@code name} from the shared MCP server registered for
     * {@code serverPath}. The call blocks until the removal completes; failures are logged,
     * not thrown. When no shared server exists for the path, only a warning is logged.
     *
     * @param serverPath server path; normalized before the shared server lookup
     * @param name       name of the tool to remove
     */
    public void removeTool(String serverPath, String name) {
        final String path = this.normalizeServerPath(serverPath);
        LOG.debug("Removing tool from shared server - name: {}, path: {}", name, path);

        final McpAsyncServer server = this.sharedServerMap.get(path);
        if (server == null) {
            LOG.warn("No shared server found for path: {}", path);
            return;
        }

        try {
            server.removeTool(name).block();
            Set<String> activeProtocols = this.getSupportedProtocols(path);
            LOG.info("Removed tool '{}' from shared server for path: {} (removed from protocols: {})", name, path, activeProtocols);
        } catch (Exception ex) {
            LOG.error("Failed to remove tool '{}' from shared server for path: {}", name, path, ex);
        }
    }

    /**
     * Lists the protocol keys that have a transport registered for {@code serverPath}.
     *
     * @param serverPath server path; normalized before lookup
     * @return a new set of protocol keys, empty when the path has no composite provider
     */
    public Set<String> getSupportedProtocols(String serverPath) {
        final String path = this.normalizeServerPath(serverPath);
        final CompositeTransportProvider composite = this.compositeTransportMap.get(path);
        if (composite == null) {
            return new HashSet<>();
        }
        return composite.getSupportedProtocols();
    }

    /**
     * Strips a trailing {@code "/streamablehttp"} from {@code path} so that the result can be
     * used as the shared-server key; any other path is returned unchanged.
     *
     * @param path server path, may be {@code null}
     * @return the path without the suffix, or {@code null} when {@code path} is {@code null}
     */
    private String normalizeServerPath(String path) {
        final String streamableSuffix = "/streamablehttp";
        if (path == null) {
            return null;
        }
        if (!path.endsWith(streamableSuffix)) {
            return path;
        }
        String trimmed = path.substring(0, path.length() - streamableSuffix.length());
        LOG.debug("Normalized Streamable HTTP path from '{}' to '{}' for shared server", path, trimmed);
        return trimmed;
    }

    private static class CompositeTransportProvider
    implements McpServerTransportProvider {
        // Registered transports keyed by protocol name; the typed getters look entries up
        // under the manager's SSE_PROTOCOL and STREAMABLE_HTTP_PROTOCOL constants.
        private final Map<String, Object> transports = new ConcurrentHashMap<String, Object>();
        // Last factory passed to setSessionFactory; addTransport forwards it to transports
        // that are registered afterwards. Null until setSessionFactory has been called.
        private volatile McpServerSession.Factory sessionFactory;
        // One synchronized set per protocol, created in addTransport and cleared on graceful
        // shutdown. Nothing in this class adds elements to these sets.
        private final Map<String, Set<String>> protocolSessions = new ConcurrentHashMap<String, Set<String>>();

        // Private: instances are created only by the enclosing manager.
        private CompositeTransportProvider() {
        }

        /**
         * Registers {@code transportProvider} under {@code protocol}, replacing any previous
         * entry, and resets that protocol's session set. If a session factory has already
         * been set and the transport is a {@code McpServerTransportProvider}, the factory is
         * forwarded to it immediately.
         *
         * @param protocol          protocol key
         * @param transportProvider transport to register
         */
        public void addTransport(String protocol, Object transportProvider) {
            this.transports.put(protocol, transportProvider);

            Set<String> emptySessions = Collections.synchronizedSet(new HashSet<String>());
            this.protocolSessions.put(protocol, emptySessions);

            if (Objects.nonNull(this.sessionFactory)) {
                if (transportProvider instanceof McpServerTransportProvider) {
                    McpServerTransportProvider provider = (McpServerTransportProvider) transportProvider;
                    provider.setSessionFactory(this.sessionFactory);
                }
            }
            LOG.debug("Added transport '{}' to composite provider", protocol);
        }

        /**
         * Returns the transport stored under the SSE protocol key.
         *
         * @return the SSE transport provider, or {@code null} when none is stored or the
         *         stored object is of a different type
         */
        public ShenyuSseServerTransportProvider getSseTransport() {
            final Object stored = this.transports.get(ShenyuMcpServerManager.SSE_PROTOCOL);
            if (stored instanceof ShenyuSseServerTransportProvider) {
                return (ShenyuSseServerTransportProvider) stored;
            }
            return null;
        }

        /**
         * Returns the transport stored under the Streamable HTTP protocol key.
         *
         * @return the Streamable HTTP transport provider, or {@code null} when none is stored
         *         or the stored object is of a different type
         */
        public ShenyuStreamableHttpServerTransportProvider getStreamableHttpTransport() {
            final Object stored = this.transports.get(ShenyuMcpServerManager.STREAMABLE_HTTP_PROTOCOL);
            if (stored instanceof ShenyuStreamableHttpServerTransportProvider) {
                return (ShenyuStreamableHttpServerTransportProvider) stored;
            }
            return null;
        }

        /**
         * Returns whatever transport object is stored under {@code protocol}.
         *
         * @param protocol protocol key
         * @return the stored transport, or {@code null} when the key is unknown
         */
        public Object getTransport(String protocol) {
            final Object stored = this.transports.get(protocol);
            return stored;
        }

        /**
         * Tells whether a transport is stored under {@code protocol}.
         *
         * @param protocol protocol key
         * @return {@code true} when the key is present
         */
        public boolean hasProtocol(String protocol) {
            final boolean present = this.transports.containsKey(protocol);
            return present;
        }

        /**
         * Returns a copy of the protocol keys that currently have a transport stored.
         *
         * @return a new, independent set of protocol keys
         */
        public Set<String> getSupportedProtocols() {
            Set<String> snapshot = new HashSet<>();
            snapshot.addAll(this.transports.keySet());
            return snapshot;
        }

        /**
         * Remembers {@code sessionFactory} and forwards it to every stored transport that is a
         * {@code McpServerTransportProvider}. A failure on one transport is logged and does
         * not stop the remaining transports from being updated.
         *
         * @param sessionFactory factory to store and propagate
         */
        public void setSessionFactory(McpServerSession.Factory sessionFactory) {
            this.sessionFactory = sessionFactory;
            // Propagation runs while holding the monitor of the transports map.
            synchronized (this.transports) {
                for (Object candidate : this.transports.values()) {
                    if (candidate instanceof McpServerTransportProvider) {
                        try {
                            ((McpServerTransportProvider) candidate).setSessionFactory(sessionFactory);
                        } catch (Exception ex) {
                            LOG.error("Failed to set session factory on transport: {}", candidate.getClass().getSimpleName(), ex);
                        }
                    }
                }
            }
            LOG.debug("Session factory set on composite transport with {} transports", this.transports.size());
        }

        /**
         * Broadcasts a notification through every stored transport. Errors raised by an
         * individual transport are logged and suppressed, so the returned {@code Mono} does
         * not fail because of them.
         *
         * @param method notification method name
         * @param params notification payload
         * @return a {@code Mono} completing once every transport has been attempted; completes
         *         immediately when no transport is stored
         */
        public Mono<Void> notifyClients(String method, Object params) {
            if (this.transports.isEmpty()) {
                LOG.debug("No transports available for client notification");
                return Mono.empty();
            }
            LOG.debug("Broadcasting notification '{}' to {} transports", method, this.transports.size());
            return Flux.fromIterable(this.transports.entrySet())
                    .flatMap(entry -> this.notifyThroughTransport(entry.getKey(), entry.getValue(), method, params))
                    .then()
                    .doOnSuccess(aVoid -> LOG.debug("Client notification broadcast completed"));
        }

        /** Sends one notification through a single transport, logging and swallowing its errors. */
        private Mono<Void> notifyThroughTransport(String protocol, Object transport, String method, Object params) {
            if (!(transport instanceof McpServerTransportProvider)) {
                LOG.warn("Transport '{}' does not implement McpServerTransportProvider", protocol);
                return Mono.empty();
            }
            return ((McpServerTransportProvider) transport).notifyClients(method, params)
                    .doOnSuccess(aVoid -> LOG.debug("Successfully notified {} clients", protocol))
                    .doOnError(e -> LOG.warn("Failed to notify {} clients: {}", protocol, e.getMessage()))
                    .onErrorComplete();
        }

        /**
         * Gracefully closes every stored transport and, once all have been attempted, clears
         * the transport and session maps. Errors raised by an individual transport are logged
         * and suppressed.
         *
         * @return a {@code Mono} completing when shutdown has finished; completes immediately
         *         when no transport is stored
         */
        public Mono<Void> closeGracefully() {
            if (this.transports.isEmpty()) {
                return Mono.empty();
            }
            LOG.info("Initiating graceful shutdown of {} transports", this.transports.size());
            return Flux.fromIterable(this.transports.entrySet())
                    .flatMap(entry -> this.closeSingleTransport(entry.getKey(), entry.getValue()))
                    .then()
                    .doOnSuccess(aVoid -> {
                        this.transports.clear();
                        this.protocolSessions.clear();
                        LOG.info("Graceful shutdown completed - all transports and sessions cleared");
                    });
        }

        /** Gracefully closes a single transport, logging and swallowing its errors. */
        private Mono<Void> closeSingleTransport(String protocol, Object transport) {
            if (!(transport instanceof McpServerTransportProvider)) {
                LOG.warn("Transport '{}' does not implement graceful shutdown", protocol);
                return Mono.empty();
            }
            return ((McpServerTransportProvider) transport).closeGracefully()
                    .doOnSuccess(aVoid -> LOG.info("Successfully closed {} transport", protocol))
                    .doOnError(e -> LOG.error("Error closing {} transport: {}", protocol, e.getMessage()))
                    .onErrorComplete();
        }
    }
}

