/*
 * Decompiled with CFR 0.152.
 */
package org.jkiss.dbeaver.model.ai.impl;

import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.concurrent.Flow;
import org.jkiss.code.NotNull;
import org.jkiss.code.Nullable;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.Log;
import org.jkiss.dbeaver.model.DBPDataSource;
import org.jkiss.dbeaver.model.ai.AIAssistant;
import org.jkiss.dbeaver.model.ai.AICommandRequest;
import org.jkiss.dbeaver.model.ai.AICommandResult;
import org.jkiss.dbeaver.model.ai.AIMessage;
import org.jkiss.dbeaver.model.ai.AIMessageType;
import org.jkiss.dbeaver.model.ai.AITextUtils;
import org.jkiss.dbeaver.model.ai.AITranslateRequest;
import org.jkiss.dbeaver.model.ai.engine.AIDatabaseContext;
import org.jkiss.dbeaver.model.ai.engine.AIEngine;
import org.jkiss.dbeaver.model.ai.engine.AIEngineRequest;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponse;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponseChunk;
import org.jkiss.dbeaver.model.ai.engine.TooManyRequestsException;
import org.jkiss.dbeaver.model.ai.impl.LogSubscriber;
import org.jkiss.dbeaver.model.ai.impl.MessageChunk;
import org.jkiss.dbeaver.model.ai.prompt.AIPromptBuilder;
import org.jkiss.dbeaver.model.ai.prompt.AIPromptFormatter;
import org.jkiss.dbeaver.model.ai.registry.AIEngineRegistry;
import org.jkiss.dbeaver.model.ai.registry.AIFormatterRegistry;
import org.jkiss.dbeaver.model.ai.registry.AISettingsRegistry;
import org.jkiss.dbeaver.model.ai.utils.AIUtils;
import org.jkiss.dbeaver.model.ai.utils.DatabaseMetadataUtils;
import org.jkiss.dbeaver.model.ai.utils.ThrowableSupplier;
import org.jkiss.dbeaver.model.app.DBPWorkspace;
import org.jkiss.dbeaver.model.exec.DBExecUtils;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import org.jkiss.dbeaver.model.sql.SQLUtils;
import org.jkiss.dbeaver.utils.RuntimeUtils;

/**
 * Default {@link AIAssistant} implementation.
 * <p>
 * Builds a system prompt (optionally enriched with a database metadata snapshot), sends it together
 * with the user's text to an {@link AIEngine}, and post-processes the completion into
 * {@link MessageChunk}s. Engine calls that fail with {@link TooManyRequestsException} are retried a
 * bounded number of times.
 */
public class AIAssistantImpl implements AIAssistant {
    private static final Log log = Log.getLog(AIAssistantImpl.class);

    /** Maximum number of attempts made when the engine reports too many requests. */
    private static final int MANY_REQUESTS_RETRIES = 3;
    /** Pause between attempts after a {@link TooManyRequestsException}, in milliseconds. */
    private static final int MANY_REQUESTS_TIMEOUT = 500;

    private final AISettingsRegistry settingsRegistry = AISettingsRegistry.getInstance();
    private final AIEngineRegistry engineRegistry = AIEngineRegistry.getInstance();
    private final AIFormatterRegistry formatterRegistry = AIFormatterRegistry.getInstance();

    /**
     * Does nothing: this implementation keeps no workspace-specific state.
     *
     * @param workspace the workspace being initialized (unused)
     */
    @Override
    public void initialize(@NotNull DBPWorkspace workspace) {
        // Intentionally empty: registries are obtained when the instance is created.
    }

    /**
     * Translates natural-language text into SQL.
     *
     * @param monitor progress monitor
     * @param request translation request; its engine is used when set, otherwise the active engine
     * @return SQL produced from the completion by {@link AITextUtils#convertToSQL}
     * @throws DBException if prompt building, the engine request or completion processing fails
     */
    @Override
    @NotNull
    public String translateTextToSql(@NotNull DBRProgressMonitor monitor, @NotNull AITranslateRequest request) throws DBException {
        AIEngine engine = resolveEngine(request.engine());
        AIMessage userMessage = new AIMessage(AIMessageType.USER, request.text());
        MessageChunk[] messageChunks = completeSqlTranslation(monitor, engine, request.context(), userMessage);
        return AITextUtils.convertToSQL(userMessage, messageChunks, request.context().getExecutionContext().getDataSource());
    }

    /**
     * Runs a natural-language command and splits the answer into SQL and explanatory text.
     * <p>
     * The text of the last code chunk becomes the SQL; all text chunks are concatenated in order,
     * without separators.
     *
     * @param monitor progress monitor
     * @param request command request; its engine is used when set, otherwise the active engine
     * @return result whose SQL part is {@code null} when the completion contained no code chunk
     * @throws DBException if prompt building, the engine request or completion processing fails
     */
    @Override
    @NotNull
    public AICommandResult command(@NotNull DBRProgressMonitor monitor, @NotNull AICommandRequest request) throws DBException {
        AIEngine engine = resolveEngine(request.engine());
        MessageChunk[] messageChunks = completeSqlTranslation(
            monitor, engine, request.context(), AIMessage.userMessage(request.text()));

        String finalSQL = null;
        StringBuilder messages = new StringBuilder();
        for (MessageChunk chunk : messageChunks) {
            if (chunk instanceof MessageChunk.Code) {
                // A later code chunk overrides an earlier one.
                finalSQL = ((MessageChunk.Code) chunk).text();
            } else if (chunk instanceof MessageChunk.Text) {
                messages.append(((MessageChunk.Text) chunk).text());
            }
        }
        return new AICommandResult(finalSQL, messages.toString());
    }

    /**
     * Reports whether the currently active engine is configured well enough to be used.
     *
     * @throws DBException if the active engine cannot be resolved
     */
    @Override
    public boolean hasValidConfiguration() throws DBException {
        return getActiveEngine().hasValidConfiguration();
    }

    /**
     * Post-processes a raw completion with {@link AIUtils#processCompletion} and splits it into chunks
     * using the SQL dialect of the context's data source.
     *
     * @param monitor    progress monitor
     * @param context    database context the completion refers to
     * @param completion raw completion text returned by the engine
     * @return the processed completion split into message chunks
     * @throws DBException if completion processing fails
     */
    protected MessageChunk[] processAndSplitCompletion(@NotNull DBRProgressMonitor monitor, @NotNull AIDatabaseContext context, @NotNull String completion) throws DBException {
        String processedCompletion = AIUtils.processCompletion(
            monitor, context.getExecutionContext(), context.getScopeObject(), completion, formatter(), true);
        return AITextUtils.splitIntoChunks(
            SQLUtils.getDialectFromDataSource(context.getExecutionContext().getDataSource()), processedCompletion);
    }

    /**
     * Returns the engine explicitly requested by the caller, or the active engine when none was given.
     */
    private AIEngine resolveEngine(@Nullable AIEngine requested) throws DBException {
        return requested != null ? requested : getActiveEngine();
    }

    /**
     * Runs the shared natural-language-to-SQL pipeline used by {@link #translateTextToSql} and {@link #command}.
     * <p>
     * The steps are: build the system prompt, truncate the chat to the engine's context size,
     * request a completion, then process and split the first returned variant.
     *
     * @throws DBException if any step fails or the engine returned no completion variants
     */
    private MessageChunk[] completeSqlTranslation(
        @NotNull DBRProgressMonitor monitor,
        @NotNull AIEngine engine,
        @NotNull AIDatabaseContext context,
        @NotNull AIMessage userMessage
    ) throws DBException {
        String prompt = buildPrompt(monitor, engine, context)
            .addGoals("Translate natural language text to SQL.")
            .addOutputFormats(
                "Place any explanation or comments before the SQL code block.",
                "Provide the SQL query in a fenced Markdown code block.")
            .build();
        List<AIMessage> chatMessages = List.of(AIMessage.systemMessage(prompt), userMessage);
        AIEngineRequest completionRequest = new AIEngineRequest(
            AIUtils.truncateMessages(true, chatMessages, engine.getMaxContextSize(monitor)));
        AIEngineResponse completionResponse = requestCompletion(engine, monitor, completionRequest);
        // Fail with a descriptive error instead of an IndexOutOfBoundsException on an empty response.
        if (completionResponse.variants().isEmpty()) {
            throw new DBException("AI engine returned an empty completion response");
        }
        return processAndSplitCompletion(monitor, context, completionResponse.variants().get(0));
    }

    /**
     * Invokes {@code supplier}, retrying up to {@link #MANY_REQUESTS_RETRIES} attempts in total.
     * <p>
     * A retry happens only when the supplier throws {@link TooManyRequestsException}. Successive
     * attempts are separated by {@link #MANY_REQUESTS_TIMEOUT} ms. Any other exception propagates
     * immediately.
     *
     * @throws DBException if every attempt was rejected; the last rejection is attached as the cause
     */
    private static <T> T callWithRetry(ThrowableSupplier<T, DBException> supplier) throws DBException {
        TooManyRequestsException lastError = null;
        for (int attempt = 1; attempt <= MANY_REQUESTS_RETRIES; attempt++) {
            try {
                return supplier.get();
            } catch (TooManyRequestsException e) {
                lastError = e;
                // No point pausing after the final attempt.
                if (attempt < MANY_REQUESTS_RETRIES) {
                    log.debug("Too many engine requests. Retry after " + MANY_REQUESTS_TIMEOUT + "ms");
                    RuntimeUtils.pause(MANY_REQUESTS_TIMEOUT);
                }
            }
        }
        throw new DBException("Request failed after " + MANY_REQUESTS_RETRIES + " attempts", lastError);
    }

    /**
     * Resolves the completion engine selected in the AI settings.
     *
     * @throws DBException if the engine cannot be obtained from the registry
     */
    protected AIEngine getActiveEngine() throws DBException {
        return engineRegistry.getCompletionEngine(settingsRegistry.getSettings().activeEngine());
    }

    /**
     * Requests a single (non-streaming) completion.
     * <p>
     * Retries on {@link TooManyRequestsException}. Logs the request and response at debug level
     * when the engine has logging enabled.
     *
     * @throws DBException on failure; non-{@link DBException} errors are wrapped
     */
    protected AIEngineResponse requestCompletion(@NotNull AIEngine engine, @NotNull DBRProgressMonitor monitor, @NotNull AIEngineRequest request) throws DBException {
        try {
            if (engine.isLoggingEnabled()) {
                log.debug("Requesting completion [request=" + request + "]");
            }
            AIEngineResponse completionResponse = callWithRetry(() -> engine.requestCompletion(monitor, request));
            if (engine.isLoggingEnabled()) {
                log.debug("Received completion [response=" + completionResponse + "]");
            }
            return completionResponse;
        } catch (DBException e) {
            throw e;
        } catch (Exception e) {
            throw new DBException("Error requesting completion", e);
        }
    }

    /**
     * Requests a streaming completion.
     * <p>
     * The engine call itself is retried on {@link TooManyRequestsException}. The returned publisher
     * forwards subscriptions to the engine's publisher. When engine logging is enabled, each
     * subscriber is wrapped in a {@link LogSubscriber}.
     *
     * @throws DBException on failure (also logged); non-{@link DBException} errors are wrapped
     */
    protected Flow.Publisher<AIEngineResponseChunk> requestCompletionStream(@NotNull AIEngine engine, @NotNull DBRProgressMonitor monitor, @NotNull AIEngineRequest request) throws DBException {
        try {
            Flow.Publisher<AIEngineResponseChunk> publisher =
                callWithRetry(() -> engine.requestCompletionStream(monitor, request));
            // Captured once so every subscription sees the same setting.
            boolean loggingEnabled = engine.isLoggingEnabled();
            return subscriber -> {
                if (loggingEnabled) {
                    log.debug("Requesting completion stream [request=" + request + "]");
                    publisher.subscribe(new LogSubscriber(log, subscriber));
                } else {
                    publisher.subscribe(subscriber);
                }
            };
        } catch (Exception e) {
            log.error("Error requesting completion stream", e);
            if (e instanceof DBException) {
                throw (DBException) e;
            }
            throw new DBException("Error requesting completion stream", e);
        }
    }

    /**
     * Returns the {@code "core"} prompt formatter from the formatter registry.
     *
     * @throws DBException if the formatter cannot be obtained
     */
    protected AIPromptFormatter formatter() throws DBException {
        return formatterRegistry.getFormatter("core");
    }

    /**
     * Creates a prompt builder using the default {@link #formatter()}.
     *
     * @see #buildPrompt(DBRProgressMonitor, AIEngine, AIPromptFormatter, AIDatabaseContext)
     */
    protected AIPromptBuilder buildPrompt(@NotNull DBRProgressMonitor monitor, @NotNull AIEngine engine, @Nullable AIDatabaseContext context) throws DBException {
        return buildPrompt(monitor, engine, formatter(), context);
    }

    /**
     * Creates a prompt builder for the context's data source and lets
     * {@link #describeDatabaseMetadata} add a database snapshot to it.
     * <p>
     * With a context, metadata description runs through {@link DBExecUtils#tryExecuteRecover}.
     * Any {@link DBException} it raises is wrapped in an {@link InvocationTargetException} for that
     * runner.
     *
     * @param context database context, or {@code null} to build a prompt without a data source
     */
    protected AIPromptBuilder buildPrompt(@NotNull DBRProgressMonitor monitor, @NotNull AIEngine engine, @NotNull AIPromptFormatter formatter, @Nullable AIDatabaseContext context) throws DBException {
        AIPromptBuilder promptBuilder = AIPromptBuilder.createForDataSource(context != null ? context.getDataSource() : null, formatter);
        if (context == null) {
            // Still delegated so that subclasses overriding describeDatabaseMetadata are consulted.
            describeDatabaseMetadata(monitor, engine, formatter, null, promptBuilder);
            return promptBuilder;
        }
        // NOTE(review): the (Object) cast is carried over from the compiled code; it may be selecting
        // between tryExecuteRecover overloads for this lambda - confirm before removing it.
        DBExecUtils.tryExecuteRecover((Object) monitor, context.getExecutionContext().getDataSource(), param -> {
            try {
                describeDatabaseMetadata(monitor, engine, formatter, context, promptBuilder);
            } catch (DBException e) {
                throw new InvocationTargetException(e);
            }
        });
        return promptBuilder;
    }

    /**
     * Adds a database metadata description to {@code promptBuilder} using the default {@link #formatter()}.
     */
    protected void describeDatabaseMetadata(@NotNull DBRProgressMonitor monitor, @NotNull AIEngine engine, @Nullable AIDatabaseContext context, @NotNull AIPromptBuilder promptBuilder) throws DBException {
        describeDatabaseMetadata(monitor, engine, formatter(), context, promptBuilder);
    }

    /**
     * Adds a description of {@code context} as a database snapshot to {@code promptBuilder}.
     * <p>
     * The description is produced by {@link DatabaseMetadataUtils#describeContext} and bounded by
     * the engine's maximum request tokens. Does nothing when {@code context} is {@code null}.
     */
    protected void describeDatabaseMetadata(@NotNull DBRProgressMonitor monitor, @NotNull AIEngine engine, @NotNull AIPromptFormatter formatter, @Nullable AIDatabaseContext context, @NotNull AIPromptBuilder promptBuilder) throws DBException {
        if (context == null) {
            return;
        }
        String description = DatabaseMetadataUtils.describeContext(
            monitor, context, formatter, AIUtils.getMaxRequestTokens(engine, monitor));
        promptBuilder.addDatabaseSnapshot(description);
    }
}

