/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.mcpserver;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.common.lease.Releasable;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.http.HttpChunk;
import org.opensearch.ml.action.mcpserver.McpAsyncServerHolder;
import org.opensearch.ml.action.mcpserver.McpToolsHelper;
import org.opensearch.ml.common.MLIndex;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.StreamingRestChannel;
import org.opensearch.transport.client.node.NodeClient;
import reactor.core.Exceptions;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;

/**
 * {@link McpServerTransportProvider} that serves MCP sessions over OpenSearch streaming REST
 * channels, framing outgoing payloads as server-sent events ({@code event:} / {@code data:} lines).
 *
 * <p>A new connection goes through {@code handleSseConnection}: the session management index is
 * initialised, the session id is written to that index, MCP tools are loaded if no session is
 * registered yet, and finally the session is stored in memory and the client receives an
 * {@code endpoint} event. Client messages are then routed through {@code handleMessage}.
 */
public class OpenSearchMcpServerTransportProvider
implements McpServerTransportProvider {
    @Generated
    private static final Logger log = LogManager.getLogger(OpenSearchMcpServerTransportProvider.class);
    /** SSE event name of the first chunk sent to a client; its data is the message endpoint path. */
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    /** Mapper used to deserialize incoming JSON-RPC messages and serialize outgoing ones. */
    private final ObjectMapper objectMapper;
    /** Factory for new sessions; assigned through {@code setSessionFactory}, not in the constructor. */
    private McpServerSession.Factory sessionFactory;
    /** Used to initialise the MCP session management index before a session is recorded. */
    private final MLIndicesHandler mlIndicesHandler;
    /** Used to load all MCP tools when the first session is registered on this instance. */
    private final McpToolsHelper mcpToolsHelper;
    /** Sessions registered on this instance, keyed by session id. */
    private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<String, McpServerSession>();

    /**
     * Creates a transport provider from its collaborators.
     *
     * @param mlIndicesHandler handler used to initialise the MCP session management index
     * @param mcpToolsHelper helper used to load MCP tools
     * @param objectMapper mapper for JSON-RPC messages; checked with {@code Assert.notNull}
     */
    public OpenSearchMcpServerTransportProvider(MLIndicesHandler mlIndicesHandler, McpToolsHelper mcpToolsHelper, ObjectMapper objectMapper) {
        // Only the mapper is validated; the other two collaborators are stored as passed in.
        Assert.notNull(objectMapper, "ObjectMapper must not be null");
        this.objectMapper = objectMapper;
        this.mcpToolsHelper = mcpToolsHelper;
        this.mlIndicesHandler = mlIndicesHandler;
    }

    /**
     * Installs the factory that {@code handleSseConnection} uses to create a session for each
     * new connection.
     *
     * @param sessionFactory factory producing {@link McpServerSession} instances
     */
    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Broadcasts a notification to every registered session.
     *
     * <p>A failure on one session is logged and swallowed so that the remaining sessions are
     * still notified. The emptiness check runs when this method is called, not on subscription.
     *
     * @param method notification method name forwarded to each session
     * @param params notification payload forwarded to each session
     * @return a {@link Mono} that completes once every session has been attempted, or an empty
     *         {@link Mono} when no session is registered at call time
     */
    public Mono<Void> notifyClients(String method, Object params) {
        if (this.sessions.isEmpty()) {
            log.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        log.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
        // fromIterable (not fromStream): a java.util.stream.Stream can be consumed only once, so a
        // Flux built from it fails on re-subscription or retry. Iterating the map's value view
        // yields a fresh iterator per subscriber, as closeGracefully already does.
        return Flux.fromIterable(this.sessions.values())
            .flatMap(session -> session.sendNotification(method, params)
                .doOnError(e -> log.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()))
                .onErrorComplete())
            .then();
    }

    /**
     * Drops all registered streaming channels and asks every session to close gracefully.
     *
     * @return a {@link Mono} that completes once all sessions have finished closing
     */
    public Mono<Void> closeGracefully() {
        // The channel registry is cleared immediately, at call time rather than on subscription.
        McpAsyncServerHolder.CHANNELS.clear();
        Flux<McpServerSession> activeSessions = Flux.fromIterable(this.sessions.values())
            .doFirst(() -> log.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()));
        return activeSessions
            .flatMap(session -> session.closeGracefully())
            .then();
    }

    /**
     * Opens a new SSE connection: creates a session bound to {@code channel}, makes sure the MCP
     * session management index exists, and then continues the registration in {@code addSession}.
     *
     * @param channel streaming channel that will carry events for this session
     * @param appendToBaseUrl selects which message endpoint path is advertised to the client
     * @param nodeId id of the node stored alongside the session in the management index
     * @param client client used to write the session document
     * @return a {@link Mono} emitting the initial endpoint chunk, or an error if any step of the
     *         registration fails
     */
    public Mono<HttpChunk> handleSseConnection(StreamingRestChannel channel, boolean appendToBaseUrl, String nodeId, NodeClient client) {
        return Mono.create(sink -> {
            OpenSearchMcpSessionTransport sessionTransport = new OpenSearchMcpSessionTransport(channel);
            McpServerSession session = this.sessionFactory.create(sessionTransport);
            String sessionId = session.getId();
            ActionListener<Boolean> initIndexListener = ActionListener.wrap(created -> {
                if (Boolean.TRUE.equals(created)) {
                    log.debug("Successfully created MCP session management index");
                    this.addSession(sessionId, session, appendToBaseUrl, nodeId, client, channel, sink);
                } else {
                    // This is a failure path, so it is logged at error level like the other failures.
                    log.error("Failed to create MCP session management index for session: {}", sessionId);
                    sink.error(new IllegalStateException(String.format(Locale.ROOT, "Failed to create MCP session management index for session: %s", sessionId)));
                }
            }, e -> {
                // Keep the root cause in both the log and the propagated exception.
                log.error("Failed to create session management index for session: {}", sessionId, e);
                sink.error(new IllegalStateException("Failed to create session management index for session: " + sessionId, e));
            });
            this.mlIndicesHandler.initMLMcpSessionManagementIndex(initIndexListener);
        });
    }

    /**
     * Writes the session document ({@code node_id}, {@code status}, {@code create_time}) to the
     * MCP session management index under the session id, then continues with
     * {@code reloadAllMcpTools} when the index call reports {@link RestStatus#CREATED}.
     *
     * @param sessionId id of the session, used as the document id
     * @param session session being registered
     * @param appendToBaseUrl selects which message endpoint path is advertised to the client
     * @param nodeId id of the node stored in the session document
     * @param client client used to execute the index request
     * @param channel streaming channel bound to the session
     * @param sink sink that receives the final endpoint chunk or the failure
     */
    private void addSession(String sessionId, McpServerSession session, boolean appendToBaseUrl, String nodeId, NodeClient client, StreamingRestChannel channel, MonoSink<HttpChunk> sink) {
        Map<String, Object> source = ImmutableMap.of("node_id", nodeId, "status", "active", "create_time", Instant.now());
        IndexRequest indexRequest = new IndexRequest(MLIndex.MCP_SESSION_MANAGEMENT.getIndexName()).id(sessionId).source(source);
        // The listener is passed inline so its response type is inferred from client.index,
        // avoiding a raw ActionListener on which r.status() cannot be resolved.
        client.index(indexRequest, ActionListener.wrap(r -> {
            if (r != null && r.status() == RestStatus.CREATED) {
                this.reloadAllMcpTools(sessionId, session, appendToBaseUrl, channel, sink);
            } else {
                log.error("Failed to create new SSE connection for session: {}", sessionId);
                sink.error(new IllegalStateException("Failed to create new SSE connection for session: " + sessionId));
            }
        }, e -> {
            log.error("Failed to write sessionId into MCP session management index", e);
            sink.error(e);
        }));
    }

    /**
     * Loads all MCP tools when no session is registered yet, then registers the session in
     * memory. When at least one session is already registered the load step is skipped.
     *
     * <p>Every outcome of the load terminates {@code sink}: success continues with
     * {@code initSessionInMemory}, while an unsuccessful result or an exception signals an error.
     *
     * @param sessionId id of the session being registered
     * @param session session being registered
     * @param appendToBaseUrl selects which message endpoint path is advertised to the client
     * @param channel streaming channel bound to the session
     * @param sink sink that receives the final endpoint chunk or the failure
     */
    private void reloadAllMcpTools(String sessionId, McpServerSession session, boolean appendToBaseUrl, StreamingRestChannel channel, MonoSink<HttpChunk> sink) {
        if (!this.sessions.isEmpty()) {
            this.initSessionInMemory(sessionId, session, appendToBaseUrl, channel, sink);
            return;
        }
        ActionListener<Boolean> reloadMcpToolsListener = ActionListener.wrap(reloadResult -> {
            if (Boolean.TRUE.equals(reloadResult)) {
                this.initSessionInMemory(sessionId, session, appendToBaseUrl, channel, sink);
            } else {
                // Without this branch the sink is never signalled and the SSE Mono hangs forever.
                log.error("Failed to reload mcp tools for session: {}", sessionId);
                sink.error(new OpenSearchException("Failed to create SSE connection because the target node MCP server failed to init tools"));
            }
        }, e -> {
            log.error("Failed to reload mcp tools", e);
            sink.error(new OpenSearchException(String.format(Locale.ROOT, "Failed to create SSE connection because the target node MCP server failed to init tools with error: %s", e.getMessage())));
        });
        this.mcpToolsHelper.autoLoadAllMcpTools(reloadMcpToolsListener);
    }

    /**
     * Registers the session and its channel in memory and completes {@code sink} with the
     * initial {@code endpoint} event, whose data is the path the client must post messages to.
     *
     * @param sessionId id of the session being registered
     * @param session session to store in the in-memory session map
     * @param appendToBaseUrl when true the plugin-prefixed endpoint path is advertised,
     *        otherwise the short {@code /sse/message} path
     * @param channel streaming channel stored in {@code McpAsyncServerHolder.CHANNELS}
     * @param sink sink completed with the endpoint chunk
     */
    private void initSessionInMemory(String sessionId, McpServerSession session, boolean appendToBaseUrl, StreamingRestChannel channel, MonoSink<HttpChunk> sink) {
        log.debug("Created new SSE connection for session: {}", sessionId);
        this.sessions.put(sessionId, session);
        log.debug("Sending initial endpoint event to session: {}", sessionId);
        final String endpointTemplate;
        if (appendToBaseUrl) {
            endpointTemplate = "/_plugins/_ml/mcp/sse/message?sessionId=%s";
        } else {
            endpointTemplate = "/sse/message?sessionId=%s";
        }
        String endpoint = String.format(Locale.ROOT, endpointTemplate, sessionId);
        McpAsyncServerHolder.CHANNELS.put(sessionId, channel);
        HttpChunk endpointEvent = this.createHttpChunk(ENDPOINT_EVENT_TYPE, endpoint);
        sink.success(endpointEvent);
    }

    /**
     * Routes a raw JSON-RPC message to the session identified by {@code sessionId}.
     *
     * @param sessionId id of the target session; a null or unknown id yields a
     *        "Session not found" error
     * @param requestBody raw JSON-RPC payload; a payload that cannot be deserialized yields an
     *        "Invalid message format" error
     * @return a {@link Mono} that completes when the session has handled the message, or that
     *         signals an {@link McpError} as described above
     */
    public Mono<Void> handleMessage(String sessionId, String requestBody) {
        // ConcurrentHashMap rejects null keys, so treat a missing id as an unknown session
        // rather than letting a NullPointerException escape synchronously.
        McpServerSession session = sessionId == null ? null : this.sessions.get(sessionId);
        if (session == null) {
            log.error("Session not found: {}", sessionId);
            return Mono.error(new McpError("Session not found"));
        }
        // defer keeps deserialization lazy (performed on subscription) and, unlike Mono.just,
        // does not throw at assembly time when requestBody is null.
        return Mono.defer(() -> {
            try {
                McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, requestBody);
                return session.handle(message);
            }
            catch (IOException | IllegalArgumentException e) {
                // The trailing throwable makes the logger record the stack trace as well.
                log.error("Failed to deserialize message: {}", e.getMessage(), e);
                return Mono.error(new McpError("Invalid message format"));
            }
        });
    }

    /**
     * Builds a non-terminal chunk holding one server-sent event in the form
     * {@code event: <event>\ndata: <jsonText>\n\n}, encoded as UTF-8.
     *
     * @param event SSE event name
     * @param jsonText text placed on the {@code data:} line
     * @return a chunk whose {@code isLast()} is always false
     */
    private HttpChunk createHttpChunk(String event, String jsonText) {
        byte[] payload = String.format(Locale.ROOT, "event: %s\ndata: %s\n\n", event, jsonText).getBytes(StandardCharsets.UTF_8);
        final BytesReference body = BytesReference.fromByteBuffer(ByteBuffer.wrap(payload));
        return new HttpChunk() {

            public boolean isLast() {
                return false;
            }

            public BytesReference content() {
                return body;
            }

            public void close() {
                // Only release the body when its implementation is actually releasable.
                if (!(body instanceof Releasable)) {
                    return;
                }
                ((Releasable) body).close();
            }
        };
    }

    public class OpenSearchMcpSessionTransport
    implements McpServerTransport {
        public static final String MESSAGE_EVENT_TYPE = "message";
        private final StreamingRestChannel streamingRestChannel;

        public OpenSearchMcpSessionTransport(StreamingRestChannel streamingRestChannel) {
            this.streamingRestChannel = streamingRestChannel;
        }

        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromSupplier(() -> this.writeValueAsString(message)).doOnNext(jsonText -> {
                HttpChunk event = OpenSearchMcpServerTransportProvider.this.createHttpChunk(MESSAGE_EVENT_TYPE, (String)jsonText);
                this.streamingRestChannel.sendChunk(event);
            }).doOnError(e -> {
                Throwable exception = Exceptions.unwrap((Throwable)e);
                try {
                    this.streamingRestChannel.sendResponse((RestResponse)new BytesRestResponse((RestChannel)this.streamingRestChannel, (Exception)new IllegalStateException(exception)));
                }
                catch (IOException ex) {
                    log.error("Failed to send error response during sending message", (Throwable)ex);
                }
            }).then();
        }

        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return (T)OpenSearchMcpServerTransportProvider.this.objectMapper.convertValue(data, typeRef);
        }

        public Mono<Void> closeGracefully() {
            return Mono.empty();
        }

        private String writeValueAsString(McpSchema.JSONRPCMessage message) {
            try {
                return OpenSearchMcpServerTransportProvider.this.objectMapper.writeValueAsString((Object)message);
            }
            catch (JsonProcessingException e) {
                log.error("Failed to convert the JSONRPCMessage to raw String", (Throwable)e);
                throw new RuntimeException(e);
            }
        }
    }
}

