/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shenyu.plugin.ai.transformer.request;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;
import org.apache.shenyu.common.dto.RuleData;
import org.apache.shenyu.common.dto.SelectorData;
import org.apache.shenyu.common.dto.convert.plugin.AiRequestTransformerConfig;
import org.apache.shenyu.common.dto.convert.rule.AiRequestTransformerHandle;
import org.apache.shenyu.common.enums.AiModelProviderEnum;
import org.apache.shenyu.common.enums.PluginEnum;
import org.apache.shenyu.common.utils.GsonUtils;
import org.apache.shenyu.common.utils.Singleton;
import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry;
import org.apache.shenyu.plugin.ai.transformer.request.cache.ChatClientCache;
import org.apache.shenyu.plugin.ai.transformer.request.handler.AiRequestTransformerPluginHandler;
import org.apache.shenyu.plugin.ai.transformer.request.template.AiRequestTransformerTemplate;
import org.apache.shenyu.plugin.api.ShenyuPluginChain;
import org.apache.shenyu.plugin.base.AbstractShenyuPlugin;
import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
import org.apache.shenyu.plugin.base.utils.ServerWebExchangeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

/**
 * ShenYu plugin that asks an AI chat model to rewrite the incoming request before it is
 * passed further down the plugin chain.
 *
 * <p>The model is prompted with a message assembled by {@link AiRequestTransformerTemplate}.
 * Its textual answer is parsed like a raw HTTP message: a start line, header lines, a blank
 * line and then a JSON body. The parsed headers replace the request headers and, for JSON and
 * form-urlencoded requests, the parsed body replaces the request body.
 *
 * <p>Whenever the AI answer cannot be used (missing configuration, no answer, no parsable
 * headers or body) the affected part of the request is forwarded unchanged.
 */
public class AiRequestTransformerPlugin
extends AbstractShenyuPlugin {
    private static final Logger LOG = LoggerFactory.getLogger(AiRequestTransformerPlugin.class);

    /** Cache key under which the plugin-level chat client is looked up. */
    private static final String DEFAULT_CLIENT_KEY = "default";

    /** Charset name used when URL-encoding form data. */
    private static final String UTF_8 = "UTF-8";

    /** Start-line shape of an HTTP/1.1 request, e.g. {@code POST /path HTTP/1.1}. */
    private static final String REQUEST_LINE_REGEX = "^(GET|POST|PUT|DELETE|PATCH|OPTIONS|HEAD)\\s.*\\sHTTP/1.1$";

    private final List<HttpMessageReader<?>> messageReaders;
    private final AiModelFactoryRegistry aiModelFactoryRegistry;

    /**
     * Creates the plugin.
     *
     * @param messageReaders         readers handed to the request-body rewrite utility
     * @param aiModelFactoryRegistry registry used to build a chat model when no client is cached
     */
    public AiRequestTransformerPlugin(List<HttpMessageReader<?>> messageReaders, AiModelFactoryRegistry aiModelFactoryRegistry) {
        this.messageReaders = messageReaders;
        this.aiModelFactoryRegistry = aiModelFactoryRegistry;
    }

    /**
     * Resolves the effective configuration and chat client, calls the model and forwards the
     * transformed request to the rest of the chain.
     *
     * @param exchange the current exchange
     * @param chain    the remaining plugin chain
     * @param selector the matched selector (not used here)
     * @param rule     the matched rule; its id keys the rule-level handle and chat client
     * @return completion signal of the downstream chain
     */
    protected Mono<Void> doExecute(ServerWebExchange exchange, ShenyuPluginChain chain, SelectorData selector, RuleData rule) {
        AiRequestTransformerConfig config = (AiRequestTransformerConfig) Singleton.INST.get(AiRequestTransformerConfig.class);
        if (Objects.isNull(config)) {
            config = new AiRequestTransformerConfig();
        }
        AiRequestTransformerHandle ruleHandle = (AiRequestTransformerHandle) AiRequestTransformerPluginHandler.CACHED_HANDLE.get().obtainHandle(CacheKeyUtils.INST.getKey(rule));
        ChatClient client = ChatClientCache.getInstance().getClient(DEFAULT_CLIENT_KEY);
        if (Objects.nonNull(ruleHandle)) {
            // A rule-level handle replaces the plugin-level config entirely. The values are copied
            // into a fresh object so the shared singleton instance is not modified here.
            AiRequestTransformerConfig ruleConfig = new AiRequestTransformerConfig();
            Optional.ofNullable(ruleHandle.getProvider()).ifPresent(ruleConfig::setProvider);
            Optional.ofNullable(ruleHandle.getBaseUrl()).ifPresent(ruleConfig::setBaseUrl);
            Optional.ofNullable(ruleHandle.getApiKey()).ifPresent(ruleConfig::setApiKey);
            Optional.ofNullable(ruleHandle.getModel()).ifPresent(ruleConfig::setModel);
            Optional.ofNullable(ruleHandle.getContent()).ifPresent(ruleConfig::setContent);
            config = ruleConfig;
            client = ChatClientCache.getInstance().getClient(rule.getId());
        }

        String baseUrl = config.getBaseUrl();
        String apiKey = config.getApiKey();
        String provider = config.getProvider();
        if (Objects.isNull(baseUrl) || Objects.isNull(apiKey) || Objects.isNull(provider)) {
            LOG.error("Missing configurations: {}", describeMissing(baseUrl, apiKey, provider));
            // Without a usable model configuration the request is forwarded untouched.
            return chain.execute(exchange);
        }

        if (Objects.isNull(client)) {
            ChatModel aiModel = this.aiModelFactoryRegistry
                    .getFactory(AiModelProviderEnum.getByName(provider))
                    .createAiModel(AiRequestTransformerPluginHandler.convertConfig(config));
            // NOTE(review): without a rule handle the lookup above uses "default", but the new
            // client is cached under rule.getId() - confirm this key mismatch is intended.
            client = ChatClientCache.getInstance().init(rule.getId(), aiModel);
        }

        AiRequestTransformerTemplate template = new AiRequestTransformerTemplate(config.getContent(), exchange.getRequest());
        final ChatClient chatClient = client;
        return template.assembleMessage()
                // The chat call blocks, so it is moved off the event loop.
                .flatMap(message -> Mono.fromCallable(() -> chatClient.prompt().user(message).call().content())
                        .subscribeOn(Schedulers.boundedElastic()))
                .flatMap(aiResponse -> convertHeader(exchange, aiResponse)
                        .flatMap(mutated -> convertBody(mutated, this.messageReaders, aiResponse)))
                // fromCallable completes empty when the model returns null content; the chain
                // must still run, so fall back to the original exchange in that case.
                .defaultIfEmpty(exchange)
                .flatMap(chain::execute);
    }

    /**
     * Builds the comma-separated list of required settings that are {@code null}.
     * Callers only invoke this when at least one of the arguments is {@code null}.
     */
    private static String describeMissing(String baseUrl, String apiKey, String provider) {
        StringBuilder missing = new StringBuilder();
        if (Objects.isNull(baseUrl)) {
            missing.append("baseUrl, ");
        }
        if (Objects.isNull(apiKey)) {
            missing.append("apiKey, ");
        }
        if (Objects.isNull(provider)) {
            missing.append("provider, ");
        }
        // Drop the trailing ", " separator.
        if (missing.length() >= 2) {
            missing.setLength(missing.length() - 2);
        }
        return missing.toString();
    }

    /**
     * Replaces the request body with the body found in the AI response, depending on the
     * request content type (JSON or form-urlencoded). Other content types, and AI responses
     * without a usable JSON body, leave the exchange unchanged.
     *
     * @param exchange   exchange whose headers have already been transformed
     * @param readers    readers handed to the body rewrite utility
     * @param aiResponse raw text returned by the model
     * @return the exchange to forward
     */
    private static Mono<ServerWebExchange> convertBody(ServerWebExchange exchange, List<HttpMessageReader<?>> readers, String aiResponse) {
        MediaType mediaType = exchange.getRequest().getHeaders().getContentType();
        final String newBody;
        if (MediaType.APPLICATION_JSON.isCompatibleWith(mediaType)) {
            newBody = convertBodyJson(aiResponse);
        } else if (MediaType.APPLICATION_FORM_URLENCODED.isCompatibleWith(mediaType)) {
            newBody = convertBodyFormData(aiResponse);
        } else {
            return Mono.just(exchange);
        }
        if (Objects.isNull(newBody)) {
            // Mono.just(null) would throw; keep the original body instead of failing the request.
            LOG.warn("AI request transformer plugin: no JSON body found in AiResponse, request body is kept unchanged");
            return Mono.just(exchange);
        }
        return ServerWebExchangeUtils.rewriteRequestBody(exchange, readers, originalBody -> Mono.just(newBody));
    }

    /**
     * Extracts the JSON body of the AI response.
     *
     * @param aiResponse raw text returned by the model
     * @return the JSON body, or {@code null} when none could be extracted
     */
    static String convertBodyJson(String aiResponse) {
        return extractJsonBodyFromHttpResponse(aiResponse);
    }

    /**
     * Extracts the JSON body of the AI response and renders its entries as form-urlencoded data.
     *
     * @param aiResponse raw text returned by the model
     * @return the encoded form body, or {@code null} when no JSON body could be extracted
     */
    static String convertBodyFormData(String aiResponse) {
        String json = extractJsonBodyFromHttpResponse(aiResponse);
        if (Objects.isNull(json)) {
            return null;
        }
        Map<String, Object> formDataMap = GsonUtils.getInstance().toObjectMap(json);
        return mapToFormUrlEncoded(formDataMap);
    }

    /**
     * Returns the JSON body of an HTTP-message-shaped AI response, i.e. everything after the
     * first blank line, re-serialized through {@link GsonUtils}.
     *
     * @param aiResponse raw text returned by the model
     * @return the JSON body, or {@code null} when the input is empty, has no blank separator
     *         line, has nothing after it, or the remainder is not wrapped in {@code {}} / {@code []}
     */
    public static String extractJsonBodyFromHttpResponse(String aiResponse) {
        if (Objects.isNull(aiResponse) || aiResponse.isEmpty()) {
            return null;
        }
        String[] lines = aiResponse.split("\\R");
        int separatorIndex = -1;
        for (int i = 0; i < lines.length; ++i) {
            if (lines[i].trim().isEmpty()) {
                separatorIndex = i;
                break;
            }
        }
        if (separatorIndex == -1 || separatorIndex == lines.length - 1) {
            return null;
        }
        StringBuilder bodyBuilder = new StringBuilder();
        for (int i = separatorIndex + 1; i < lines.length; ++i) {
            bodyBuilder.append(lines[i]).append("\n");
        }
        String body = bodyBuilder.toString().trim();
        boolean looksLikeObject = body.startsWith("{") && body.endsWith("}");
        boolean looksLikeArray = body.startsWith("[") && body.endsWith("]");
        if (looksLikeObject || looksLikeArray) {
            // NOTE(review): array bodies are accepted here but are converted through a Map;
            // verify how GsonUtils.convertToMap treats a JSON array.
            Map<String, Object> requestBodyMap = GsonUtils.getInstance().convertToMap(body);
            return GsonUtils.getInstance().toJson(requestBodyMap);
        }
        return null;
    }

    /**
     * Renders a map as {@code application/x-www-form-urlencoded} text
     * ({@code key1=value1&key2=value2}), URL-encoding keys and values as UTF-8.
     * Values are converted with {@link String#valueOf(Object)}.
     *
     * @param map entries to encode; may be {@code null}
     * @return the encoded string, empty when the map is {@code null} or empty
     * @throws IllegalStateException if the UTF-8 encoding is unavailable
     */
    public static String mapToFormUrlEncoded(Map<String, Object> map) {
        if (Objects.isNull(map) || map.isEmpty()) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        try {
            for (Map.Entry<String, Object> entry : map.entrySet()) {
                if (sb.length() > 0) {
                    sb.append('&');
                }
                sb.append(URLEncoder.encode(entry.getKey(), UTF_8));
                sb.append('=');
                sb.append(URLEncoder.encode(String.valueOf(entry.getValue()), UTF_8));
            }
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 encoding is not supported", e);
        }
        return sb.toString();
    }

    /**
     * Replaces all request headers with the headers parsed from the AI response.
     * When no headers can be parsed the exchange is returned unchanged.
     *
     * @param exchange   the current exchange
     * @param aiResponse raw text returned by the model
     * @return an exchange carrying the rewritten request
     */
    private static Mono<ServerWebExchange> convertHeader(ServerWebExchange exchange, String aiResponse) {
        HttpHeaders newHeaders = extractHeadersFromAiResponse(aiResponse);
        if (newHeaders.isEmpty()) {
            // Nothing parsable came back; do not strip the original headers.
            return Mono.just(exchange);
        }
        // The mutated request must be attached to a mutated exchange, otherwise the change
        // depends on the builder sharing the original header map.
        ServerWebExchange mutated = exchange.mutate()
                .request(builder -> builder.headers(httpHeaders -> {
                    httpHeaders.clear();
                    httpHeaders.putAll(newHeaders);
                }))
                .build();
        return Mono.just(mutated);
    }

    /**
     * Parses the header section of an HTTP-message-shaped AI response. Parsing starts after
     * the first line that looks like an HTTP/1.1 status or request line and stops at the
     * first blank line; lines without a {@code name: value} shape are skipped.
     *
     * @param aiResponse raw text returned by the model
     * @return the parsed headers, empty when the input is empty or has no start line
     */
    static HttpHeaders extractHeadersFromAiResponse(String aiResponse) {
        HttpHeaders headers = new HttpHeaders();
        if (Objects.isNull(aiResponse) || aiResponse.isEmpty()) {
            return headers;
        }
        try (BufferedReader reader = new BufferedReader(new StringReader(aiResponse))) {
            boolean headerSectionStarted = false;
            String line;
            while (Objects.nonNull(line = reader.readLine())) {
                if (!headerSectionStarted) {
                    // Skip everything up to and including the start line.
                    headerSectionStarted = line.startsWith("HTTP/1.1") || line.matches(REQUEST_LINE_REGEX);
                    continue;
                }
                if (line.trim().isEmpty()) {
                    break;
                }
                int colonIndex = line.indexOf(':');
                if (colonIndex <= 0) {
                    continue;
                }
                String name = line.substring(0, colonIndex).trim();
                if (name.isEmpty()) {
                    continue;
                }
                String value = line.substring(colonIndex + 1).trim();
                headers.add(name, value);
            }
        } catch (IOException e) {
            LOG.error("AI request transformer plugin: extract headers from AiResponse fail", e);
        }
        return headers;
    }

    /**
     * Returns the execution order of this plugin.
     *
     * @return the code of {@link PluginEnum#AI_REQUEST_TRANSFORMER}
     */
    public int getOrder() {
        return PluginEnum.AI_REQUEST_TRANSFORMER.getCode();
    }

    /**
     * Returns the plugin name.
     *
     * @return the name of {@link PluginEnum#AI_REQUEST_TRANSFORMER}
     */
    public String named() {
        return PluginEnum.AI_REQUEST_TRANSFORMER.getName();
    }
}

