马大波 1 gadu atpakaļ
vecāks
revīzija
1deeb4d9e2

+ 25 - 6
src/main/java/com/xiaobao/gateway/protocol/dto/MediaServerDTO.java

@@ -6,13 +6,23 @@ import lombok.NoArgsConstructor;
 import lombok.Setter;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 @Getter
 @Setter
 @NoArgsConstructor
 @AllArgsConstructor
 public class MediaServerDTO {
+    private static final List<MediaServerDTO> MOCK_MEDIA_SERVERS = new ArrayList<>();
+
+    static {
+        MOCK_MEDIA_SERVERS.add(new MediaServerDTO("192.168.66.73", 1935, 8000)); // 媒体服务器
+        MOCK_MEDIA_SERVERS.add(new MediaServerDTO("192.168.66.114", 1935, 8001)); // 媒体服务器
+        MOCK_MEDIA_SERVERS.add(new MediaServerDTO("127.0.0.1", 8007, 8002)); // Echo服务器
+    }
+
     /**
      * 服务器地址
      */
@@ -32,11 +42,20 @@ public class MediaServerDTO {
      *
      * @return 服务器列表
      */
-    public static List<MediaServerDTO> getMediaServerList() {
-        List<MediaServerDTO> serverList = new ArrayList<>();
-        serverList.add(new MediaServerDTO("192.168.66.73", 1935, 8000)); // 本地服务器
-        serverList.add(new MediaServerDTO("39.101.185.102", 1935, 8001)); // 公网服务器
-        // serverList.add(new MediaServerDTO("127.0.0.1", 8007, 8002)); // Echo服务器
-        return serverList;
+    public static List<MediaServerDTO> getServerList() {
+        return MOCK_MEDIA_SERVERS;
+    }
+
+    /**
+     * 获取媒体服务器路由表
+     *
+     * @return 路由表
+     */
+    public static Map<Integer, MediaServerDTO> getRouteTable() {
+        Map<Integer, MediaServerDTO> table = new HashMap<>();
+        for (MediaServerDTO server : MOCK_MEDIA_SERVERS) {
+            table.put(server.getProxyPort(), server);
+        }
+        return table;
     }
 }

+ 4 - 15
src/main/java/com/xiaobao/gateway/protocol/proxy/forward/ForwardProxyClientInHandler.java

@@ -9,11 +9,11 @@ import io.netty.channel.ChannelInboundHandlerAdapter;
 
 public class ForwardProxyClientInHandler extends ChannelInboundHandlerAdapter {
     private final Channel inboundChannel;
-    private final MediaServerDTO mediaServer;
+    private final MediaServerDTO server;
 
-    public ForwardProxyClientInHandler(Channel inboundChannel, MediaServerDTO mediaServer) {
+    public ForwardProxyClientInHandler(Channel inboundChannel, MediaServerDTO server) {
         this.inboundChannel = inboundChannel;
-        this.mediaServer = mediaServer;
+        this.server = server;
     }
 
     @Override
@@ -29,25 +29,14 @@ public class ForwardProxyClientInHandler extends ChannelInboundHandlerAdapter {
     public void channelRead(final ChannelHandlerContext ctx, Object msg) {
         if (msg instanceof ByteBuf) {
             ByteBuf in = (ByteBuf) msg;
-            if (in.readableBytes() < 6) {
-                return;
-            }
-            // 标记读索引位置
-            in.markReaderIndex();
             // 读取消息长度(4 字节)
             int len = in.readInt();
             // 读取消息类型(2 字节)
             short identifier = in.readShort();
-            if (mediaServer == null || mediaServer.getProxyPort() != identifier) {
+            if (server == null || server.getProxyPort() != identifier) {
                 // 没有匹配到直接丢弃
                 return;
             }
-            // 检查是否有足够的字节来读取消息体
-            if (in.readableBytes() < len) {
-                // 重置读索引
-                in.resetReaderIndex();
-                return;
-            }
             // 读取消息体
             ByteBuf body = ctx.alloc().buffer(len);
             in.readBytes(body);

+ 1 - 1
src/main/java/com/xiaobao/gateway/protocol/proxy/forward/ForwardProxyServer.java

@@ -36,7 +36,7 @@ public class ForwardProxyServer {
     private static final String REMOTE_HOST = "127.0.0.1";
 
     public static void main(String[] args) throws InterruptedException {
-        List<MediaServerDTO> serverList = MediaServerDTO.getMediaServerList();
+        List<MediaServerDTO> serverList = MediaServerDTO.getServerList();
         ExecutorService executorService = Executors.newFixedThreadPool(serverList.size());
         for (MediaServerDTO server : serverList) {
             executorService.submit(() -> {

+ 2 - 17
src/main/java/com/xiaobao/gateway/protocol/proxy/reverse/ReverseProxyClientInHandler.java

@@ -17,12 +17,8 @@ public class ReverseProxyClientInHandler extends ChannelInboundHandlerAdapter {
     }
 
     @Override
-    public void channelActive(ChannelHandlerContext ctx) {
-        if (!inboundChannel.isActive()) {
-            ReverseProxyServerInHandler.closeOnFlush(ctx.channel());
-        } else {
-            ctx.read();
-        }
+    public void channelActive(ChannelHandlerContext ctx) throws Exception {
+        ctx.read();
     }
 
     @Override
@@ -46,15 +42,4 @@ public class ReverseProxyClientInHandler extends ChannelInboundHandlerAdapter {
             }
         });
     }
-
-    @Override
-    public void channelInactive(ChannelHandlerContext ctx) {
-        ReverseProxyServerInHandler.closeOnFlush(inboundChannel);
-    }
-
-    @Override
-    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
-        cause.printStackTrace();
-        ReverseProxyServerInHandler.closeOnFlush(ctx.channel());
-    }
 }

+ 0 - 50
src/main/java/com/xiaobao/gateway/protocol/proxy/reverse/ReverseProxyHolder.java

@@ -1,50 +0,0 @@
-package com.xiaobao.gateway.protocol.proxy.reverse;
-
-import com.xiaobao.gateway.protocol.dto.MediaServerDTO;
-import io.netty.bootstrap.Bootstrap;
-import io.netty.channel.*;
-import io.netty.channel.socket.SocketChannel;
-import lombok.Getter;
-import lombok.NoArgsConstructor;
-
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-
-@Getter
-@NoArgsConstructor
-public class ReverseProxyHolder {
-    private final Map<Integer, Channel> connections = new ConcurrentHashMap<>();
-
-    public ReverseProxyHolder(ChannelHandlerContext ctx) {
-        Channel inboundChannel = ctx.channel();
-        List<MediaServerDTO> serverList = MediaServerDTO.getMediaServerList();
-        for (MediaServerDTO server : serverList) {
-            Bootstrap b = new Bootstrap();
-            b.group(inboundChannel.eventLoop())
-                .channel(inboundChannel.getClass())
-                .handler(new ChannelInitializer<SocketChannel>() {
-                    @Override
-                    protected void initChannel(SocketChannel ch) throws Exception {
-                        ChannelPipeline pipeline = ch.pipeline();
-                        pipeline.addLast(new ReverseProxyClientInHandler(inboundChannel, server));
-                    }
-                })
-                .option(ChannelOption.AUTO_READ, false);
-            ChannelFuture f = b.connect(server.getRemoteHost(), server.getRemotePort());
-            Channel outboundChannel = f.channel();
-            connections.put(server.getProxyPort(), outboundChannel);
-            f.addListener((ChannelFutureListener) future -> {
-                if (future.isSuccess()) {
-                    inboundChannel.read();
-                } else {
-                    inboundChannel.close();
-                }
-            });
-        }
-    }
-
-    public Channel getChannel(int identifier) {
-        return connections.get(identifier);
-    }
-}

+ 45 - 49
src/main/java/com/xiaobao/gateway/protocol/proxy/reverse/ReverseProxyServerInHandler.java

@@ -1,46 +1,67 @@
 package com.xiaobao.gateway.protocol.proxy.reverse;
 
+import com.xiaobao.gateway.protocol.dto.MediaServerDTO;
+import io.netty.bootstrap.Bootstrap;
 import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.Channel;
-import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.*;
+import io.netty.channel.socket.SocketChannel;
 
 import java.util.Map;
 
 public class ReverseProxyServerInHandler extends ChannelInboundHandlerAdapter {
-    private ReverseProxyHolder reverseProxyHolder;
+    private static final Map<Integer, MediaServerDTO> mediaRouteTable = MediaServerDTO.getRouteTable();
+
+    private Channel outboundChannel;
 
     @Override
-    public void channelActive(ChannelHandlerContext ctx) {
-        this.reverseProxyHolder = new ReverseProxyHolder(ctx);
+    public void channelActive(ChannelHandlerContext ctx) throws Exception {
+        ctx.read();
     }
 
     @Override
     public void channelRead(final ChannelHandlerContext ctx, Object msg) {
-        if (msg instanceof ByteBuf && reverseProxyHolder != null) {
+        if (msg instanceof ByteBuf) {
             ByteBuf in = (ByteBuf) msg;
-            if (in.readableBytes() < 6) {
-                return;
-            }
-            // 标记读索引位置
-            in.markReaderIndex();
             // 读取消息长度(4 字节)
             int len = in.readInt();
             // 读取消息类型(2 字节)
-            short identifier = in.readShort();
-            // 获取路由通道
-            Channel outboundChannel = reverseProxyHolder.getChannel(identifier);
+            int identifier = in.readShort();
             if (outboundChannel == null) {
-                ctx.fireChannelRead(msg);
-            } else {
-                // 检查是否有足够的字节来读取消息体
-                if (in.readableBytes() < len) {
-                    // 重置读索引
-                    in.resetReaderIndex();
-                    return;
+                if (mediaRouteTable.containsKey(identifier)) {
+                    MediaServerDTO server = mediaRouteTable.get(identifier);
+                    Channel inboundChannel = ctx.channel();
+                    Bootstrap b = new Bootstrap();
+                    b.group(inboundChannel.eventLoop())
+                        .channel(inboundChannel.getClass())
+                        .handler(new ChannelInitializer<SocketChannel>() {
+                            @Override
+                            protected void initChannel(SocketChannel ch) throws Exception {
+                                ChannelPipeline pipeline = ch.pipeline();
+                                pipeline.addLast(new ReverseProxyClientInHandler(inboundChannel, server));
+                            }
+                        })
+                        .option(ChannelOption.AUTO_READ, false);
+                    ChannelFuture f = b.connect(server.getRemoteHost(), server.getRemotePort());
+                    outboundChannel = f.channel();
+                    f.addListener((ChannelFutureListener) future -> {
+                        if (future.isSuccess()) {
+                            ByteBuf body = ctx.alloc().buffer(len);
+                            in.readBytes(body);
+                            outboundChannel.writeAndFlush(body).addListener((ChannelFutureListener) writeFuture -> {
+                                if (writeFuture.isSuccess()) {
+                                    inboundChannel.read();
+                                } else {
+                                    future.channel().close();
+                                }
+                            });
+                        } else {
+                            inboundChannel.close();
+                        }
+                    });
+                } else {
+                    ctx.close();
                 }
+            } else {
                 // 读取消息体
                 ByteBuf body = ctx.alloc().buffer(len);
                 in.readBytes(body);
@@ -58,29 +79,4 @@ public class ReverseProxyServerInHandler extends ChannelInboundHandlerAdapter {
             ctx.fireChannelRead(msg);
         }
     }
-
-    @Override
-    public void channelInactive(ChannelHandlerContext ctx) {
-        if (reverseProxyHolder != null) {
-            Map<Integer, Channel> connections = reverseProxyHolder.getConnections();
-            for (Channel outboundChannel : connections.values()) {
-                closeOnFlush(outboundChannel);
-            }
-        }
-    }
-
-    @Override
-    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
-        cause.printStackTrace();
-        closeOnFlush(ctx.channel());
-    }
-
-    /**
-     * Closes the specified channel after all queued write requests are flushed.
-     */
-    static void closeOnFlush(Channel ch) {
-        if (ch.isActive()) {
-            ch.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
-        }
-    }
 }