/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.dubbo.rpc.rocketmq;

import org.apache.dubbo.common.URL;
import org.apache.dubbo.common.constants.CommonConstants;
import org.apache.dubbo.remoting.Channel;
import org.apache.dubbo.remoting.buffer.ChannelBuffer;
import org.apache.dubbo.remoting.buffer.DynamicChannelBuffer;
import org.apache.dubbo.remoting.buffer.HeapChannelBuffer;
import org.apache.dubbo.remoting.exchange.Request;
import org.apache.dubbo.remoting.exchange.Response;
import org.apache.dubbo.rpc.Exporter;
import org.apache.dubbo.rpc.Invocation;
import org.apache.dubbo.rpc.Invoker;
import org.apache.dubbo.rpc.Protocol;
import org.apache.dubbo.rpc.ProtocolServer;
import org.apache.dubbo.rpc.Result;
import org.apache.dubbo.rpc.RpcContext;
import org.apache.dubbo.rpc.RpcException;
import org.apache.dubbo.rpc.model.FrameworkModel;
import org.apache.dubbo.rpc.model.ScopeModel;
import org.apache.dubbo.rpc.protocol.AbstractProtocol;
import org.apache.dubbo.rpc.rocketmq.codec.RocketMQCountCodec;
import org.apache.rocketmq.client.common.ClientErrorCode;
import org.apache.rocketmq.client.consumer.MessageSelector;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyContext;
import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus;
import org.apache.rocketmq.client.consumer.listener.MessageListenerConcurrently;
import org.apache.rocketmq.client.exception.MQClientException;
import org.apache.rocketmq.client.producer.DefaultMQProducer;
import org.apache.rocketmq.client.producer.SendResult;
import org.apache.rocketmq.client.utils.MessageUtil;
import org.apache.rocketmq.common.message.Message;
import org.apache.rocketmq.common.message.MessageConst;
import org.apache.rocketmq.common.message.MessageExt;

import java.io.IOException;
import java.util.List;
import java.util.Objects;


/**
 * Dubbo {@link Protocol} implementation that carries RPC requests and responses over RocketMQ.
 *
 * <p>On export, the provider subscribes a push consumer to the topic given by the exporter's key and
 * answers each request with a RocketMQ reply message. On refer, a {@link RocketMQInvoker} is created
 * on top of a {@link RocketMQProtocolServer} that is shared per URL address.
 */
public class RocketMQProtocol extends AbstractProtocol {

    /** Extension name under which this protocol is registered. */
    public static final String NAME = "rocketmq";

    /**
     * Kept for backward compatibility. Note that {@link #getDefaultPort()} does not return this value.
     */
    public static final int DEFAULT_PORT = 20880;

    /** Port returned by {@link #getDefaultPort()}; 9876 is RocketMQ's conventional name-server port. */
    private static final int DEFAULT_NAME_SERVER_PORT = 9876;

    public RocketMQProtocol() {
    }

    /**
     * Looks up the {@code rocketmq} protocol extension in the given scope.
     * The method name is kept as-is for backward compatibility.
     *
     * @param scopeModel scope whose {@link Protocol} extension loader is queried
     * @return the registered {@link RocketMQProtocol} instance
     */
    public static RocketMQProtocol getDubboProtocol(ScopeModel scopeModel) {
        return (RocketMQProtocol) scopeModel.getExtensionLoader(Protocol.class).getExtension(RocketMQProtocol.NAME, false);
    }

    @Override
    public int getDefaultPort() {
        return DEFAULT_NAME_SERVER_PORT;
    }

    /**
     * Exports an invoker by subscribing the shared push consumer to the exporter's topic.
     *
     * <p>When the URL parameter {@code groupModel} equals {@code "select"}, the subscription uses an
     * SQL92 selector on {@code group}/{@code version}; otherwise every message on the topic is accepted.
     *
     * @throws RpcException if the RocketMQ client cannot be created or the subscription fails
     */
    @Override
    public <T> Exporter<T> export(Invoker<T> invoker) throws RpcException {
        URL url = invoker.getUrl();
        RocketMQExporter<T> exporter = new RocketMQExporter<T>(invoker, url, exporterMap);

        String topic = exporter.getKey();
        RocketMQProtocolServer rocketMQProtocolServer;
        try {
            rocketMQProtocolServer = this.openServer(url, CommonConstants.PROVIDER);
        } catch (Exception e) {
            String exeptionInfo = String.format("create rocketmq client fail, url is %s , topic is %s, cause is %s", url, topic, e.getMessage());
            logger.error(exeptionInfo, e);
            throw new RpcException(exeptionInfo, e);
        }
        try {
            String groupModel = url.getParameter("groupModel");
            if ("select".equals(groupModel)) {
                rocketMQProtocolServer.getDefaultMQPushConsumer().subscribe(topic, this.createMessageSelector(url));
            } else {
                rocketMQProtocolServer.getDefaultMQPushConsumer().subscribe(topic, CommonConstants.ANY_VALUE);
            }
            return exporter;
        } catch (Exception e) {
            String exeptionInfo = String.format("topic subscirbe fail, topic is %s, cause is %s", topic, e.getMessage());
            logger.error(exeptionInfo, e);
            throw new RpcException(exeptionInfo, e);
        }
    }

    /**
     * Builds an SQL92 selector matching the URL's {@code group} and/or {@code version}.
     *
     * <p>Values are emitted as single-quoted string constants, as RocketMQ's SQL92 filter requires.
     * Unquoted, they would be parsed as property names (or fail to parse, e.g. {@code 1.0.0}).
     * NOTE(review): assumes the consumer side puts {@code group}/{@code version} user properties on
     * request messages. Confirm against RocketMQInvoker.
     *
     * @throws IllegalArgumentException if neither group nor version is present on the URL
     */
    private MessageSelector createMessageSelector(URL url) {
        String group = url.getParameter(CommonConstants.GROUP_KEY);
        String version = url.getParameter(CommonConstants.VERSION_KEY);
        if (Objects.isNull(group) && Objects.isNull(version)) {
            throw new IllegalArgumentException("group and version are both null");
        }
        StringBuilder sql = new StringBuilder();
        if (Objects.nonNull(group)) {
            appendEquals(sql, CommonConstants.GROUP_KEY, group);
        }
        if (Objects.nonNull(version)) {
            if (sql.length() > 0) {
                sql.append(" and ");
            }
            appendEquals(sql, CommonConstants.VERSION_KEY, version);
        }
        return MessageSelector.bySql(sql.toString());
    }

    /** Appends {@code key = 'value'} to the selector, escaping embedded quotes SQL-style ({@code ''}). */
    private static void appendEquals(StringBuilder sql, String key, String value) {
        sql.append(key).append(" = '").append(value.replace("'", "''")).append('\'');
    }

    /**
     * Returns the protocol server cached under {@code url.getAddress()}, creating it on first use.
     *
     * <p>NOTE(review): the cache key is the address only, so a provider and a consumer pointing at the
     * same address share one server, created with whichever {@code model} came first. Confirm this
     * is intended.
     */
    private RocketMQProtocolServer openServer(URL url, String model) {
        String key = url.getAddress();
        ProtocolServer server = serverMap.get(key);
        if (server == null) {
            // Re-check under the protocol lock so concurrent callers create at most one server per key.
            synchronized (this) {
                server = serverMap.get(key);
                if (server == null) {
                    server = createServer(url, key, model);
                    serverMap.put(key, server);
                }
            }
        }
        return (RocketMQProtocolServer) server;
    }

    /** Creates and initializes a server for {@code url}, wiring the request listener to its producer. */
    private ProtocolServer createServer(URL url, String key, String model) {
        RocketMQProtocolServer rocketMQProtocolServer = new RocketMQProtocolServer();
        rocketMQProtocolServer.setModel(model);
        DubboMessageListenerConcurrently dubboMessageListenerConcurrently = new DubboMessageListenerConcurrently();
        dubboMessageListenerConcurrently.defaultMQProducer = rocketMQProtocolServer.getDefaultMQProducer();
        dubboMessageListenerConcurrently.rocketMQProtocolServer = rocketMQProtocolServer;
        rocketMQProtocolServer.setMessageListenerConcurrently(dubboMessageListenerConcurrently);
        rocketMQProtocolServer.reset(url);

        return rocketMQProtocolServer;
    }

    /**
     * Creates a consumer-side invoker backed by the shared protocol server for {@code url}.
     *
     * @throws RpcException wrapping any failure while opening the server or building the invoker
     */
    @Override
    protected <T> Invoker<T> protocolBindingRefer(Class<T> type, URL url) throws RpcException {
        try {
            RocketMQProtocolServer rocketMQProtocolServer = this.openServer(url, CommonConstants.CONSUMER);
            return new RocketMQInvoker<>(type, url, rocketMQProtocolServer);
        } catch (Exception e) {
            String exceptionInfo = String.format("protocol binding refer fail, url is %s , cause is %s ", url, e.getMessage());
            logger.error(exceptionInfo, e);
            throw new RpcException(exceptionInfo, e);
        }
    }

    /**
     * Provider-side listener: decodes each request message, invokes the exported service and sends the
     * encoded {@link Response} back as a RocketMQ reply message.
     */
    private class DubboMessageListenerConcurrently implements MessageListenerConcurrently {

        private final RocketMQCountCodec rocketmqCountCodec = new RocketMQCountCodec(FrameworkModel.defaultModel());

        private DefaultMQProducer defaultMQProducer;

        private RocketMQProtocolServer rocketMQProtocolServer;

        /**
         * Hands every message to the server's executor and acknowledges the batch immediately.
         * NOTE(review): messages are acked before processing, so a crash mid-processing loses them.
         */
        @Override
        public ConsumeConcurrentlyStatus consumeMessage(List<MessageExt> msgs, ConsumeConcurrentlyContext context) {
            for (MessageExt messageExt : msgs) {
                rocketMQProtocolServer.getExecutorService().submit(() -> executeSafely(messageExt));
            }
            return ConsumeConcurrentlyStatus.CONSUME_SUCCESS;
        }

        /**
         * Runs {@link #execute} and logs any runtime failure. Without this, an exception thrown from a
         * task passed to {@code submit} is captured in the discarded Future and never reported.
         */
        private void executeSafely(MessageExt messageExt) {
            try {
                execute(messageExt);
            } catch (RuntimeException e) {
                logger.error(String.format("handle rocketmq request fail, message is %s, cause is %s", messageExt, e.getMessage()), e);
            }
        }

        /** Decodes, invokes, encodes and replies for a single request message. */
        @SuppressWarnings("deprecation")
        private void execute(MessageExt messageExt) {
            // NOTE(review): 9876 is used as a placeholder port for the sender's address; the real port is not carried.
            RpcContext.getContext().setRemoteAddress(messageExt.getUserProperty(RocketMQProtocolConstant.SEND_ADDRESS), 9876);
            String urlString = messageExt.getUserProperty(RocketMQProtocolConstant.URL_STRING);
            URL url = URL.valueOf(urlString);

            RocketMQChannel channel = new RocketMQChannel();
            channel.setRemoteAddress(RpcContext.getContext().getRemoteAddress());
            channel.setUrl(url);
            channel.setUrlString(urlString);
            channel.setMessageExt(messageExt);
            channel.setDefaultMQProducer(defaultMQProducer);
            channel.setRocketMQCountCodec(rocketmqCountCodec);

            Response response = this.invoke(messageExt, channel, url);
            if (Objects.isNull(response)) {
                return;
            }
            ChannelBuffer buffer = this.createChannelBuffer(channel, response, url);
            if (Objects.isNull(buffer)) {
                return;
            }
            this.sendMessage(messageExt, buffer, url, urlString);
        }

        /**
         * Decodes the request and invokes the exporter registered for the message's topic.
         *
         * @return the response to send, or {@code null} if the request's deadline (the {@code timeout}
         *         user property, compared against the current epoch millis) has already passed
         */
        @SuppressWarnings("deprecation")
        private Response invoke(MessageExt messageExt, Channel channel, URL url) {
            Response response = new Response();
            try {
                String timeoutString = messageExt.getUserProperty(CommonConstants.TIMEOUT_KEY);
                long timeout = Long.parseLong(timeoutString);
                if (logger.isDebugEnabled()) {
                    logger.debug(String.format("reply message ext is : %s", messageExt));
                }
                // A reply can only be created when the request carries the cluster property.
                if (Objects.isNull(messageExt.getProperty(MessageConst.PROPERTY_CLUSTER))) {
                    MQClientException exception = new MQClientException(ClientErrorCode.CREATE_REPLY_MESSAGE_EXCEPTION,
                        "create reply message fail, requestMessage error, property[" + MessageConst.PROPERTY_CLUSTER + "] is null.");
                    response.setErrorMessage(exception.getMessage());
                    response.setStatus(Response.BAD_REQUEST);
                    logger.error(exception);
                    return response;
                }
                HeapChannelBuffer heapChannelBuffer = new HeapChannelBuffer(messageExt.getBody());
                Request request = (Request) rocketmqCountCodec.decode(channel, heapChannelBuffer);
                // Echo the request id, as Dubbo's exchange layer does, so the reply can be correlated.
                response.setId(request.getId());
                Invocation inv = (Invocation) request.getData();
                if (timeout < System.currentTimeMillis()) {
                    logger.warn(String.format("message timeoute time is %d invocation is %s ", timeout, inv));
                    return null;
                }
                String topic = messageExt.getTopic();
                Exporter<?> exporter = exporterMap.get(topic);
                if (Objects.isNull(exporter)) {
                    String exceptionInfo = String.format("no exporter found for topic %s, url is %s", topic, url);
                    response.setErrorMessage(exceptionInfo);
                    response.setStatus(Response.SERVICE_NOT_FOUND);
                    logger.error(exceptionInfo);
                    return response;
                }
                Invoker<?> invoker = exporter.getInvoker();

                RpcContext.getContext().setRemoteAddress(channel.getRemoteAddress());
                Result result = invoker.invoke(inv);
                response.setStatus(Response.OK);
                response.setResult(result);
            } catch (Exception e) {
                String exceptionInfo = String.format("data decode or invoke fail, url is %s cause is %s", url, e.getMessage());
                response.setErrorMessage(exceptionInfo);
                response.setStatus(Response.BAD_REQUEST);
                logger.error(exceptionInfo, e);
            }
            return response;
        }

        /**
         * Encodes {@code response}. If encoding fails, the response is turned into a BAD_REQUEST error
         * and encoded again into a fresh buffer.
         *
         * @return the encoded buffer, or {@code null} if even the error response cannot be encoded
         */
        private ChannelBuffer createChannelBuffer(Channel channel, Response response, URL url) {
            ChannelBuffer buffer = new DynamicChannelBuffer(2048);
            try {
                rocketmqCountCodec.encode(channel, buffer, response);
                return buffer;
            } catch (Exception e) {
                String exceptionInfo = String.format("encode fail, url is %s cause is %s", url, e.getMessage());
                response.setErrorMessage(exceptionInfo);
                response.setStatus(Response.BAD_REQUEST);
                logger.error(exceptionInfo, e);
            }
            // Use a fresh buffer: the failed attempt may have left partially written bytes behind.
            ChannelBuffer errorBuffer = new DynamicChannelBuffer(2048);
            try {
                rocketmqCountCodec.encode(channel, errorBuffer, response);
                return errorBuffer;
            } catch (Exception e1) {
                String exceptionInfo1 = String.format("encode exception response fail, url is %s cause is %s", url, e1.getMessage());
                logger.error(exceptionInfo1, e1);
                return null;
            }
        }

        /**
         * Sends the encoded response as a reply to {@code messageExt}.
         *
         * @return {@code true} if the send call returned, {@code false} if it threw
         */
        private boolean sendMessage(MessageExt messageExt, ChannelBuffer buffer, URL url, String urlString) {
            try {
                // array() exposes the whole backing array including unused capacity; copy only the encoded bytes.
                byte[] body = new byte[buffer.readableBytes()];
                buffer.getBytes(buffer.readerIndex(), body);
                Message newMessage = MessageUtil.createReplyMessage(messageExt, body);
                newMessage.putUserProperty(RocketMQProtocolConstant.SEND_ADDRESS, RocketMQProtocolConstant.LOCAL_ADDRESS.getHostString());
                newMessage.putUserProperty(RocketMQProtocolConstant.URL_STRING, urlString);
                SendResult sendResult = defaultMQProducer.send(newMessage, 3000);
                if (logger.isDebugEnabled()) {
                    logger.debug(String.format("send result is : %s", sendResult));
                }
                return true;
            } catch (Exception e) {
                String exceptionInfo = String.format("send response fail, url is %s cause is %s", url, e.getMessage());
                logger.error(exceptionInfo, e);
                return false;
            }
        }

    }

}
