/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ratis.netty.client;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.ratis.client.DataStreamClientRpc;
import org.apache.ratis.client.RaftClientConfigKeys;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.datastream.impl.DataStreamPacketByteBuffer;
import org.apache.ratis.datastream.impl.DataStreamRequestByteBuf;
import org.apache.ratis.datastream.impl.DataStreamRequestByteBuffer;
import org.apache.ratis.datastream.impl.DataStreamRequestFilePositionCount;
import org.apache.ratis.io.StandardWriteOption;
import org.apache.ratis.io.WriteOption;
import org.apache.ratis.netty.NettyConfigKeys;
import org.apache.ratis.netty.NettyDataStreamUtils;
import org.apache.ratis.netty.NettyUtils;
import org.apache.ratis.netty.client.NettyClientReplies;
import org.apache.ratis.protocol.ClientInvocationId;
import org.apache.ratis.protocol.DataStreamReply;
import org.apache.ratis.protocol.DataStreamRequest;
import org.apache.ratis.protocol.RaftPeer;
import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
import org.apache.ratis.protocol.exceptions.TimeoutIOException;
import org.apache.ratis.security.TlsConf;
import org.apache.ratis.thirdparty.io.netty.bootstrap.Bootstrap;
import org.apache.ratis.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.ratis.thirdparty.io.netty.channel.Channel;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelFuture;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelFutureListener;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelHandler;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelInboundHandler;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelInitializer;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelOption;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelPipeline;
import org.apache.ratis.thirdparty.io.netty.channel.EventLoopGroup;
import org.apache.ratis.thirdparty.io.netty.channel.socket.SocketChannel;
import org.apache.ratis.thirdparty.io.netty.handler.codec.ByteToMessageDecoder;
import org.apache.ratis.thirdparty.io.netty.handler.codec.MessageToMessageEncoder;
import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
import org.apache.ratis.thirdparty.io.netty.util.concurrent.Future;
import org.apache.ratis.thirdparty.io.netty.util.concurrent.GenericFutureListener;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.MemoizedSupplier;
import org.apache.ratis.util.NetUtils;
import org.apache.ratis.util.Preconditions;
import org.apache.ratis.util.ReferenceCountedObject;
import org.apache.ratis.util.SizeInBytes;
import org.apache.ratis.util.TimeDuration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NettyClientStreamRpc
implements DataStreamClientRpc {
    public static final Logger LOG = LoggerFactory.getLogger(NettyClientStreamRpc.class);

    // Encoders carry no state and are @Sharable, so one instance of each is added to every channel pipeline.
    static final MessageToMessageEncoder<DataStreamRequestByteBuffer> ENCODER = new Encoder();
    static final MessageToMessageEncoder<DataStreamRequestByteBuf> ENCODER_BYTE_BUF = new EncoderByteBuf();
    static final MessageToMessageEncoder<DataStreamRequestFilePositionCount> ENCODER_FILE_POSITION_COUNT = new EncoderFilePositionCount();
    static final MessageToMessageEncoder<ByteBuffer> ENCODER_BYTE_BUFFER = new EncoderByteBuffer();

    /** Name used in log messages and returned by {@link #toString()}. */
    private final String name;
    /** Connection to the server's data stream address. */
    private final Connection connection;
    /** Per-stream bookkeeping of requests awaiting replies. */
    private final NettyClientReplies replies = new NettyClientReplies();
    /** Timeout for non-CLOSE requests, measured from a successful write. */
    private final TimeDuration requestTimeout;
    /** Timeout for CLOSE requests (twice the request timeout). */
    private final TimeDuration closeTimeout;
    /** Minimum number of outstanding requests that triggers a flush. */
    private final int flushRequestCountMin;
    /** Minimum number of outstanding bytes that triggers a flush. */
    private final SizeInBytes flushRequestBytesMin;
    /** Counts written-but-unflushed requests to decide when to flush. */
    private final OutstandingRequests outstandingRequests = new OutstandingRequests();

    /**
     * Creates a data stream client RPC targeting the data stream address of {@code server}.
     * The underlying connection is established lazily, on first use.
     *
     * @param server the peer whose data stream address is connected to
     * @param tlsConf TLS configuration used to build the client SSL context (may yield no SSL)
     * @param properties source of timeouts, flush thresholds and worker group settings
     */
    public NettyClientStreamRpc(RaftPeer server, TlsConf tlsConf, RaftProperties properties) {
        name = JavaUtils.getClassSimpleName(getClass()) + "->" + server.getId();

        requestTimeout = RaftClientConfigKeys.DataStream.requestTimeout(properties);
        closeTimeout = requestTimeout.multiply(2.0);
        flushRequestCountMin = RaftClientConfigKeys.DataStream.flushRequestCountMin(properties);
        flushRequestBytesMin = RaftClientConfigKeys.DataStream.flushRequestBytesMin(properties);

        final InetSocketAddress serverAddress = NetUtils.createSocketAddr(server.getDataStreamAddress());
        final SslContext ssl = NettyUtils.buildSslContextForClient(tlsConf);
        final WorkerGroupGetter workers = WorkerGroupGetter.newInstance(properties);
        // The initializer is built per connect attempt so that every channel gets a fresh client handler.
        final Supplier<ChannelInitializer<SocketChannel>> initializers =
            () -> newChannelInitializer(serverAddress, ssl, getClientHandler());
        connection = new Connection(serverAddress, workers, initializers);
    }

    /**
     * Builds the last inbound handler of a client pipeline.
     * It routes each decoded {@link DataStreamReply} to the reply map of its stream, logs and closes the channel
     * on exceptions, and schedules a reconnect when the channel becomes inactive.
     */
    private ChannelInboundHandler getClientHandler() {
        return new ChannelInboundHandlerAdapter() {

            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) {
                if (msg instanceof DataStreamReply) {
                    dispatch((DataStreamReply) msg);
                    return;
                }
                LOG.error("{}: unexpected message {}", NettyClientStreamRpc.this.name, msg.getClass());
            }

            /** Hands a reply to its stream's reply map; a failure in that hand-off fails the whole map. */
            private void dispatch(DataStreamReply reply) {
                final String self = NettyClientStreamRpc.this.name;
                LOG.debug("{}: read {}", self, reply);
                final ClientInvocationId invocationId = ClientInvocationId.valueOf(reply.getClientId(), reply.getStreamId());
                final NettyClientReplies.ReplyMap target = NettyClientStreamRpc.this.replies.getReplyMap(invocationId);
                if (target != null) {
                    try {
                        target.receiveReply(reply);
                    } catch (Throwable cause) {
                        LOG.warn("{} : channelRead error:", self, cause);
                        target.completeExceptionally(cause);
                    }
                } else {
                    LOG.error("{}: {} replyMap not found for reply: {}", self, invocationId, reply);
                }
            }

            @Override
            public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
                LOG.warn("{} : exceptionCaught", NettyClientStreamRpc.this.name, cause);
                ctx.close();
            }

            @Override
            public void channelInactive(ChannelHandlerContext ctx) {
                // Closing the connection makes scheduleReconnect a no-op, so this only retries while open.
                NettyClientStreamRpc.this.connection.scheduleReconnect("channel is inactive", null);
            }
        };
    }

    /**
     * Creates the initializer that assembles a client channel pipeline.
     * In order, it adds: an optional SSL handler, the four request encoders, a reply decoder and {@code handler}.
     *
     * @param address server address; its host name and port are passed to the SSL engine
     * @param sslContext client SSL context, or {@code null} for plaintext
     * @param handler the inbound handler that consumes decoded replies
     */
    static ChannelInitializer<SocketChannel> newChannelInitializer(final InetSocketAddress address, final SslContext sslContext, final ChannelInboundHandler handler) {
        return new ChannelInitializer<SocketChannel>() {

            @Override
            public void initChannel(SocketChannel ch) {
                final ChannelPipeline pipeline = ch.pipeline();
                if (sslContext != null) {
                    pipeline.addLast("ssl", sslContext.newHandler(ch.alloc(), address.getHostName(), address.getPort()));
                }
                // Varargs addLast appends each handler in the given order, exactly like successive addLast calls.
                pipeline.addLast(
                    ENCODER,
                    ENCODER_FILE_POSITION_COUNT,
                    ENCODER_BYTE_BUFFER,
                    ENCODER_BYTE_BUF,
                    newDecoder(),
                    handler);
            }
        };
    }

    /**
     * Creates a fresh (non-sharable) decoder that turns inbound bytes into {@link DataStreamReply} objects.
     * It uses a composite cumulator, so partial frames are accumulated without copying.
     * Nothing is emitted while the buffered bytes do not yet hold a complete reply (the decode helper returns null).
     */
    static ByteToMessageDecoder newDecoder() {
        final ByteToMessageDecoder decoder = new ByteToMessageDecoder() {
            @Override
            protected void decode(ChannelHandlerContext context, ByteBuf buf, List<Object> out) {
                final Object decoded = NettyDataStreamUtils.decodeDataStreamReplyByteBuffer(buf);
                if (decoded != null) {
                    out.add(decoded);
                }
            }
        };
        decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR);
        return decoder;
    }

    /**
     * Sends {@code request} on the data stream channel and returns a future for its reply.
     * <p>
     * The request is first registered with the reply map of its (client id, stream id) and then written. It is
     * flushed immediately only when {@link OutstandingRequests} says so; otherwise it is only buffered. The returned
     * future is handed to the reply map via {@code submitRequest}; the normal completion path lives there
     * (NOTE(review): in NettyClientReplies, not visible here). This method completes the future exceptionally with:
     * <ul>
     *   <li>{@link AlreadyClosedException} if no channel is available (connection closed);</li>
     *   <li>an {@link IOException} if the channel write fails;</li>
     *   <li>a {@link TimeoutIOException} if the future is still not done {@code closeTimeout} (CLOSE requests) or
     *       {@code requestTimeout} (other requests) after a successful write.</li>
     * </ul>
     *
     * @param request the request to send
     * @return a future for the server's reply
     */
    @Override
    public CompletableFuture<DataStreamReply> streamAsync(DataStreamRequest request) {
        final CompletableFuture<DataStreamReply> f = new CompletableFuture<>();
        final ClientInvocationId clientInvocationId = ClientInvocationId.valueOf(request.getClientId(), request.getStreamId());
        final boolean isClose = request.getWriteOptionList().contains(StandardWriteOption.CLOSE);
        final NettyClientReplies.ReplyMap replyMap = replies.getReplyMap(clientInvocationId);
        final NettyClientReplies.RequestEntry requestEntry = new NettyClientReplies.RequestEntry(request);

        final Channel channel;
        final NettyClientReplies.ReplyEntry replyEntry;
        final ChannelFuture channelFuture;
        LOG.debug("{}: write begin {}", this, request);
        // Registration and write happen under the stream's reply-map lock, presumably so that the submission order
        // matches the order requests are written to the channel.
        synchronized (replyMap) {
            channel = connection.getChannelUninterruptibly();
            if (channel == null) {
                f.completeExceptionally(new AlreadyClosedException(this + ": Failed to send " + request));
                return f;
            }
            replyEntry = replyMap.submitRequest(requestEntry, isClose, f);
            final Function<DataStreamRequest, ChannelFuture> writeMethod =
                outstandingRequests.shouldFlush(flushRequestCountMin, flushRequestBytesMin, request)
                    ? channel::writeAndFlush : channel::write;
            channelFuture = writeMethod.apply(request);
        }

        channelFuture.addListener(future -> {
            if (!future.isSuccess()) {
                final IOException e = new IOException(this + ": Failed to send " + request + " to " + channel.remoteAddress(), future.cause());
                f.completeExceptionally(e);
                replyMap.fail(requestEntry);
                LOG.error("Channel write failed", e);
                return;
            }
            LOG.debug("{}: write after {}", this, request);
            // The timeout clock starts only once the write has succeeded.
            final TimeDuration timeout = isClose ? closeTimeout : requestTimeout;
            replyEntry.scheduleTimeout(() -> channel.eventLoop().schedule(() -> {
                if (!f.isDone()) {
                    f.completeExceptionally(new TimeoutIOException("Timeout " + timeout + ": Failed to send " + request + " via channel " + channel));
                    replyMap.fail(requestEntry);
                }
            }, timeout.getDuration(), timeout.getUnit()));
        });
        return f;
    }

    /**
     * Closes this client.
     * <p>
     * If some written requests have not been flushed yet, an empty buffer is written with writeAndFlush to push them
     * out, and the connection is closed once that write completes (successfully or not). If no channel is available
     * in that case, the connection is already closed and nothing more is done. Otherwise the connection is closed
     * immediately.
     */
    @Override
    public void close() {
        // Thresholds of 0 mean "flush if anything at all is outstanding"; a null request adds nothing to the counts.
        final boolean flush = outstandingRequests.shouldFlush(0, SizeInBytes.ZERO, null);
        if (!flush) {
            connection.close();
            return;
        }
        final Channel channel = connection.getChannelUninterruptibly();
        if (channel != null) {
            channel.writeAndFlush(DataStreamPacketByteBuffer.EMPTY_BYTE_BUFFER)
                .addListener(ignored -> connection.close());
        }
    }

    /** Returns the client name, i.e. "NettyClientStreamRpc->" followed by the server id. */
    @Override
    public String toString() {
        return name;
    }

    /** Encodes a bare {@link ByteBuffer}. It has no fields and is @Sharable, so one instance serves all pipelines. */
    @ChannelHandler.Sharable
    static class EncoderByteBuffer extends MessageToMessageEncoder<ByteBuffer> {
        @Override
        protected void encode(ChannelHandlerContext ctx, ByteBuffer request, List<Object> out) {
            NettyDataStreamUtils.encodeByteBuffer(request, out::add);
        }
    }

    /**
     * Encodes a {@link DataStreamRequestFilePositionCount}, allocating header buffers from the channel's allocator.
     * It has no fields and is @Sharable.
     */
    @ChannelHandler.Sharable
    static class EncoderFilePositionCount extends MessageToMessageEncoder<DataStreamRequestFilePositionCount> {
        @Override
        protected void encode(ChannelHandlerContext ctx, DataStreamRequestFilePositionCount request, List<Object> out) {
            NettyDataStreamUtils.encodeDataStreamRequestFilePositionCount(request, out::add, ctx.alloc());
        }
    }

    /**
     * Encodes a {@link DataStreamRequestByteBuf}, allocating from the channel's allocator.
     * It has no fields and is @Sharable.
     */
    @ChannelHandler.Sharable
    static class EncoderByteBuf extends MessageToMessageEncoder<DataStreamRequestByteBuf> {
        @Override
        protected void encode(ChannelHandlerContext context, DataStreamRequestByteBuf request, List<Object> out) {
            NettyDataStreamUtils.encodeDataStreamRequestByteBuf(request, out::add, context.alloc());
        }
    }

    /**
     * Encodes a {@link DataStreamRequestByteBuffer}, allocating from the channel's allocator.
     * It has no fields and is @Sharable.
     */
    @ChannelHandler.Sharable
    static class Encoder extends MessageToMessageEncoder<DataStreamRequestByteBuffer> {
        @Override
        protected void encode(ChannelHandlerContext context, DataStreamRequestByteBuffer request, List<Object> out) {
            NettyDataStreamUtils.encodeDataStreamRequestByteBuffer(request, out::add, context.alloc());
        }
    }

    /**
     * Tracks requests written since the last flush and decides when the next write should also flush.
     * All state is accessed through the synchronized {@link #shouldFlush(int, SizeInBytes, DataStreamRequest)}.
     */
    static class OutstandingRequests {
        /** Number of requests recorded since the last flush. */
        private int count;
        /** Total data length of the requests recorded since the last flush. */
        private long bytes;

        /**
         * Pure decision on the current counters.
         * A CLOSE option always flushes. With nothing outstanding there is never a flush. Otherwise it flushes when
         * either threshold is reached or the FLUSH option is present.
         */
        private boolean needsFlush(List<WriteOption> options, int countMin, SizeInBytes bytesMin) {
            if (options.contains(StandardWriteOption.CLOSE)) {
                return true;
            }
            final boolean anythingOutstanding = count != 0 || bytes != 0L;
            return anythingOutstanding
                && (count >= countMin || bytes >= bytesMin.getSize() || options.contains(StandardWriteOption.FLUSH));
        }

        /**
         * Records {@code request} (if non-null) and decides whether to flush; the counters reset when it returns true.
         *
         * @param countMin request-count threshold
         * @param bytesMin byte threshold
         * @param request the request being written, or {@code null} to only query (used on close)
         * @return whether the caller should flush
         */
        synchronized boolean shouldFlush(int countMin, SizeInBytes bytesMin, DataStreamRequest request) {
            final List<WriteOption> options;
            if (request != null) {
                options = request.getWriteOptionList();
                count++;
                final long length = request.getDataLength();
                Preconditions.assertTrue(length >= 0L, () -> "length = " + length + " < 0, request: " + request);
                bytes += length;
            } else {
                options = Collections.emptyList();
            }

            final boolean flush = needsFlush(options, countMin, bytesMin);
            LOG.debug("flush? {}, (count, bytes)=({}, {}), min=({}, {}), request={}, options={}", flush, count, bytes, countMin, bytesMin, request, options);
            if (!flush) {
                return false;
            }
            count = 0;
            bytes = 0L;
            return true;
        }
    }

    /**
     * A lazily established connection to one server address, with reconnection support.
     * <p>
     * {@link #ref} holds a memoized supplier whose first {@code get()} starts a Netty connect. Reconnecting swaps in
     * a fresh supplier and closing sets it to {@code null}; {@link #isClosed()} is defined by that null.
     */
    static class Connection {
        /** Delay before a scheduled reconnect attempt. */
        static final TimeDuration RECONNECT = TimeDuration.valueOf(100L, TimeUnit.MILLISECONDS);

        private final InetSocketAddress address;
        private final WorkerGroupGetter workerGroup;
        private final Supplier<ChannelInitializer<SocketChannel>> channelInitializerSupplier;
        /** The current (lazy) connect attempt; {@code null} once closed. */
        private final AtomicReference<MemoizedSupplier<ChannelFuture>> ref;

        Connection(InetSocketAddress address, WorkerGroupGetter workerGroup, Supplier<ChannelInitializer<SocketChannel>> channelInitializerSupplier) {
            this.address = address;
            this.workerGroup = workerGroup;
            this.channelInitializerSupplier = channelInitializerSupplier;
            this.ref = new AtomicReference<MemoizedSupplier<ChannelFuture>>(MemoizedSupplier.valueOf(this::connect));
        }

        /** Returns the current connect future, starting the connect if needed, or {@code null} if closed. */
        ChannelFuture getChannelFuture() {
            final MemoizedSupplier<ChannelFuture> referenced = ref.get();
            return referenced != null ? referenced.get() : null;
        }

        /**
         * Waits for the current connection and returns its channel, reconnecting once if it is inactive.
         * If the current connect attempt failed, that failure propagates (Netty's syncUninterruptibly rethrows it).
         *
         * @return an active channel, or {@code null} if this connection is closed
         */
        Channel getChannelUninterruptibly() {
            final ChannelFuture future = getChannelFuture();
            if (future == null) {
                return null;
            }
            final Channel channel = future.syncUninterruptibly().channel();
            if (channel.isActive()) {
                return channel;
            }
            final ChannelFuture f = reconnect();
            return f == null ? null : f.syncUninterruptibly().channel();
        }

        private EventLoopGroup getWorkerGroup() {
            return workerGroup.get();
        }

        /** Starts a connect to {@link #address}; a failed attempt schedules a reconnect. Returns null if closed. */
        private ChannelFuture connect() {
            if (isClosed()) {
                return null;
            }
            final EventLoopGroup group = getWorkerGroup();
            return new Bootstrap()
                .group(group)
                .channel(NettyUtils.getSocketChannelClass(group))
                .handler(channelInitializerSupplier.get())
                .option(ChannelOption.SO_KEEPALIVE, true)
                .connect(address)
                // A lambda (not an anonymous class) so that 'this' is the Connection, both for the
                // scheduleReconnect call and in the log messages.
                .addListener((ChannelFutureListener) future -> {
                    if (!future.isSuccess()) {
                        scheduleReconnect(this + " failed", future.cause());
                    } else {
                        LOG.trace("{} succeed.", this);
                    }
                });
        }

        /** Schedules {@link #reconnect()} after {@link #RECONNECT}, unless this connection has been closed. */
        void scheduleReconnect(String message, Throwable cause) {
            if (isClosed()) {
                return;
            }
            LOG.warn("{}: {}; schedule reconnecting to {} in {}", this, message, address, RECONNECT);
            if (cause != null) {
                LOG.warn("", cause);
            }
            getWorkerGroup().schedule(this::reconnect, RECONNECT.getDuration(), RECONNECT.getUnit());
        }

        /**
         * Replaces the current connection with a new lazy connect attempt, unless the current one is already active.
         * The previous channel, if one was created, is closed.
         *
         * @return the (possibly new) connect future, or {@code null} if closed
         */
        private synchronized ChannelFuture reconnect() {
            // Double check: a concurrent caller may already have reconnected.
            final ChannelFuture channelFuture = getChannelFuture();
            if (channelFuture != null) {
                // awaitUninterruptibly does not rethrow a connect failure, unlike syncUninterruptibly.
                // A failed attempt therefore gets replaced below; rethrowing here would abort every
                // scheduled reconnect and leave the connection permanently failed.
                channelFuture.awaitUninterruptibly();
                if (channelFuture.isSuccess() && channelFuture.channel().isActive()) {
                    return channelFuture;
                }
            }

            final MemoizedSupplier<ChannelFuture> next = MemoizedSupplier.valueOf(this::connect);
            // Keep null (closed) as null; never resurrect a closed connection.
            final MemoizedSupplier<ChannelFuture> previous = ref.getAndUpdate(prev -> prev == null ? null : next);
            if (previous != null && previous.isInitialized()) {
                previous.get().channel().close();
            }
            return getChannelFuture();
        }

        /**
         * Closes the connection and releases the worker group.
         * If a channel exists, the group is released after the channel has closed. This method is idempotent:
         * only the first call releases the worker group, so a shared group is never released twice.
         */
        void close() {
            final MemoizedSupplier<ChannelFuture> previous = ref.getAndSet(null);
            if (previous == null) {
                return; // already closed
            }
            if (previous.isInitialized()) {
                previous.get().channel().close().addListener(future -> workerGroup.shutdownGracefully());
            } else {
                workerGroup.shutdownGracefully();
            }
        }

        boolean isClosed() {
            return ref.get() == null;
        }

        @Override
        public String toString() {
            return JavaUtils.getClassSimpleName(getClass()) + "-" + address;
        }
    }

    /**
     * Supplies the worker {@link EventLoopGroup} of a {@link Connection}.
     * <p>
     * When {@code NettyConfigKeys.DataStream.Client.workerGroupShare} is enabled, all instances share one
     * reference-counted group, which is shut down only when its last holder calls {@link #shutdownGracefully()}.
     * Otherwise each instance creates and owns its own group.
     */
    private static class WorkerGroupGetter
    implements Supplier<EventLoopGroup> {
        /**
         * Guards {@link #sharedWorkerGroup}. Retain and release must be atomic with reading and clearing the field.
         * Doing release inside AtomicReference.updateAndGet was unsafe, because that function may be re-applied.
         */
        private static final Object SHARED_WORKER_GROUP_LOCK = new Object();
        /** The shared group, or {@code null} if no instance currently holds it. Guarded by the lock above. */
        private static ReferenceCountedObject<EventLoopGroup> sharedWorkerGroup;

        private final EventLoopGroup workerGroup;

        /** Returns a getter for either the shared group (retaining it) or a new private group. */
        static WorkerGroupGetter newInstance(RaftProperties properties) {
            final boolean shared = NettyConfigKeys.DataStream.Client.workerGroupShare(properties);
            if (!shared) {
                return new WorkerGroupGetter(WorkerGroupGetter.newWorkerGroup(properties));
            }
            synchronized (SHARED_WORKER_GROUP_LOCK) {
                if (sharedWorkerGroup == null) {
                    sharedWorkerGroup = ReferenceCountedObject.wrap(WorkerGroupGetter.newWorkerGroup(properties));
                }
                final ReferenceCountedObject<EventLoopGroup> current = sharedWorkerGroup;
                return new WorkerGroupGetter(current.retain()) {

                    @Override
                    void shutdownGracefully() {
                        final boolean released;
                        synchronized (SHARED_WORKER_GROUP_LOCK) {
                            Preconditions.assertSame(current, sharedWorkerGroup, "SHARED_WORKER_GROUP");
                            released = current.release();
                            if (released) {
                                sharedWorkerGroup = null;
                            }
                        }
                        // The actual shutdown happens outside the lock; only the last holder performs it.
                        if (released) {
                            this.get().shutdownGracefully();
                        }
                    }
                };
            }
        }

        static EventLoopGroup newWorkerGroup(RaftProperties properties) {
            return NettyUtils.newEventLoopGroup(JavaUtils.getClassSimpleName(NettyClientStreamRpc.class) + "-workerGroup", NettyConfigKeys.DataStream.Client.workerGroupSize(properties), NettyConfigKeys.DataStream.Client.useEpoll(properties));
        }

        private WorkerGroupGetter(EventLoopGroup workerGroup) {
            this.workerGroup = workerGroup;
        }

        @Override
        public final EventLoopGroup get() {
            return this.workerGroup;
        }

        /** Shuts down the owned group; the shared-group subclass instead releases its reference first. */
        void shutdownGracefully() {
            this.workerGroup.shutdownGracefully();
        }
    }
}

