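// Server side of an HTTP-encoded smux-over-KCP tunnel: it accepts smux streams
// whose KCP packets travel in HTTP request and response bodies, and forwards
// each stream to a TCP address.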

package main

import (
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/xtaci/kcp-go"
	"github.com/xtaci/smux/v2"
)

const (
	// queueCapacity is the buffer size of the receive queue. The original
	// design uses it for send queues too — presumably via the client map;
	// verify there.
	queueCapacity = 10

	// clientMapTimeout is how long outgoing packets wait in a client-specific
	// queue before they are discarded.
	clientMapTimeout = time.Minute
)

// clientID is the return address of a tunnel client. It is an opaque string
// that the client sends in the Client-Id header of every HTTP request. Packets
// received in a request are tagged with it, and downstream packets queued for
// the same clientID can be delivered in the responses to that client's
// requests.
type clientID string

// clientID satisfies net.Addr so that it can travel through the
// net.PacketConn interface as a packet address.
var _ net.Addr = clientID("")

// Network reports the fixed pseudo-network name "clientid".
func (id clientID) Network() string { return "clientid" }

// String returns the identifier exactly as the client sent it.
func (id clientID) String() string { return string(id) }

// taggedPacket bundles a packet payload with the address it came from. It
// mirrors the (payload, address) pair that net.PacketConn.ReadFrom returns.
type taggedPacket struct {
	P    []byte   // packet payload
	Addr net.Addr // address the packet is associated with
}

// queuePacketConn is a net.PacketConn backed by in-memory queues rather than
// a socket. Incoming packets are injected with QueueIncoming and handed out by
// ReadFrom. Packets written with WriteTo go into per-address queues, which
// can be drained through OutgoingQueue.
type queuePacketConn struct {
	closed    chan struct{}     // closed by Close
	recvQueue chan taggedPacket // packets waiting for ReadFrom
	clients   *clientMap        // per-address outgoing queues
	localAddr net.Addr          // value returned by LocalAddr
}

// newQueuePacketConn returns an open queuePacketConn whose LocalAddr is
// localAddr. Its receive queue holds queueCapacity packets, and its per-client
// outgoing queues are managed by a client map created with clientMapTimeout.
func newQueuePacketConn(localAddr net.Addr) *queuePacketConn {
	pc := new(queuePacketConn)
	pc.closed = make(chan struct{})
	pc.recvQueue = make(chan taggedPacket, queueCapacity)
	pc.clients = newClientMap(clientMapTimeout)
	pc.localAddr = localAddr
	return pc
}

// QueueIncoming makes p, tagged with its source address addr, available to a
// future ReadFrom call. It takes ownership of p and never blocks. If the
// receive queue is already full, the packet is silently discarded.
func (conn *queuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
	pkt := taggedPacket{P: p, Addr: addr}
	select {
	case conn.recvQueue <- pkt:
	default:
		// Queue full: dropping is acceptable for datagrams.
	}
}

// OutgoingQueue returns the queue holding packets that WriteTo has addressed
// to addr. The current time is handed to the client map along with addr —
// presumably so the map can track when the queue was last used; see
// clientMap.SendQueue to confirm.
func (conn *queuePacketConn) OutgoingQueue(addr net.Addr) chan []byte {
	now := time.Now()
	return conn.clients.SendQueue(addr, now)
}

var (
	// errClosedPacketConn is the underlying error that queuePacketConn
	// reports for operations attempted once it has been closed.
	errClosedPacketConn = errors.New("closed conn")
)

// ReadFrom returns the next packet queued by QueueIncoming, together with the
// address it was queued with. The packet is copied into p; if p is shorter
// than the packet, the excess bytes are discarded. ReadFrom blocks until a
// packet is available or the conn is closed. Once the conn is closed, it
// returns a *net.OpError wrapping errClosedPacketConn, even if packets are
// still queued.
func (conn *queuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
	// Check for closure first. When both a packet and closure are ready, a
	// single select would pick between them at random.
	select {
	case <-conn.closed:
	default:
		select {
		case packet := <-conn.recvQueue:
			return copy(p, packet.P), packet.Addr, nil
		case <-conn.closed:
		}
	}
	// localAddr may legitimately be nil (newQueuePacketConn(nil)). Calling
	// Network on a nil interface would panic, so guard it.
	var network string
	if conn.localAddr != nil {
		network = conn.localAddr.Network()
	}
	return 0, nil, &net.OpError{Op: "read", Net: network, Source: conn.localAddr, Err: errClosedPacketConn}
}

// WriteTo queues a copy of p on the outgoing queue for addr; OutgoingQueue can
// later retrieve it. WriteTo never blocks. If that queue is full, the packet
// is dropped, yet len(p) and a nil error are still returned, as with an
// unreliable datagram socket. Once the conn is closed, WriteTo queues nothing
// and returns a *net.OpError wrapping errClosedPacketConn.
func (conn *queuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	// Check for closure first. A single select over both the queue and the
	// closed channel would pick at random, letting writes succeed after Close.
	select {
	case <-conn.closed:
		var network string
		if addr != nil {
			network = addr.Network()
		}
		return 0, &net.OpError{Op: "write", Net: network, Source: conn.localAddr, Addr: addr, Err: errClosedPacketConn}
	default:
	}

	// Copy the slice, because the caller may reuse it.
	c := make([]byte, len(p))
	copy(c, p)
	select {
	case conn.OutgoingQueue(addr) <- c:
		// Packet queued.
	default:
		// OK to drop packet when queue is full.
	}
	return len(c), nil
}

// Close closes the queuePacketConn by closing its closed channel. That
// unblocks a pending ReadFrom and makes ReadFrom and WriteTo report errors.
// Closing an already-closed conn returns a *net.OpError wrapping
// errClosedPacketConn.
//
// NOTE(review): the check-then-close below is not atomic. Two concurrent
// Close calls could both reach close(conn.closed) and panic. A sync.Once
// field would remove the hazard if concurrent closing is ever needed.
func (conn *queuePacketConn) Close() error {
	select {
	case <-conn.closed:
		// localAddr may legitimately be nil (newQueuePacketConn(nil)). Calling
		// Network on a nil interface would panic, so guard it.
		var network string
		if conn.localAddr != nil {
			network = conn.localAddr.Network()
		}
		return &net.OpError{Op: "close", Net: network, Addr: conn.localAddr, Err: errClosedPacketConn}
	default:
		close(conn.closed)
		return nil
	}
}

// LocalAddr returns the address passed to newQueuePacketConn, which may be nil.
func (conn *queuePacketConn) LocalAddr() net.Addr {
	addr := conn.localAddr
	return addr
}

// SetDeadline is not supported. It ignores t and always returns a
// "not implemented" error.
func (conn *queuePacketConn) SetDeadline(t time.Time) error {
	return errors.New("not implemented")
}

// SetReadDeadline is not supported; it fails exactly like SetDeadline.
func (conn *queuePacketConn) SetReadDeadline(t time.Time) error {
	return conn.SetDeadline(t)
}

// SetWriteDeadline is not supported; it fails exactly like SetDeadline.
func (conn *queuePacketConn) SetWriteDeadline(t time.Time) error {
	return conn.SetDeadline(t)
}

// httpHandler is the HTTP face of the tunnel. Packets arriving in request
// bodies are fed into conn. Packets that conn has queued for the requesting
// client are sent back in response bodies.
type httpHandler struct {
	conn *queuePacketConn
}

// ServeHTTP exchanges at most one packet in each direction per request.
//
// The request must carry a non-empty Client-Id header. That header identifies
// the client and is used as the packet's source address; without it the
// response is 400 Bad Request. A non-empty request body is queued on h.conn
// as one incoming packet. Nothing is queued in two error cases, because a
// partial packet is not a valid packet:
//   - a body that cannot be read completely gets 400 Bad Request;
//   - a body longer than maxPacketSize gets 413 Request Entity Too Large.
// Otherwise, if a packet is waiting in the client's outgoing queue, it becomes
// the response body; if none is waiting, the response body is empty.
func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	// Largest upstream packet accepted in a request body.
	const maxPacketSize = 2000

	defer req.Body.Close()

	client := req.Header.Get("Client-Id")
	if client == "" {
		rw.WriteHeader(http.StatusBadRequest)
		return
	}
	var remoteAddr net.Addr = clientID(client)

	// Read one byte past the limit so that an oversized body is detected
	// instead of being silently truncated into a corrupt packet.
	body, err := ioutil.ReadAll(io.LimitReader(req.Body, maxPacketSize+1))
	if err != nil {
		rw.WriteHeader(http.StatusBadRequest)
		return
	}
	if len(body) > maxPacketSize {
		rw.WriteHeader(http.StatusRequestEntityTooLarge)
		return
	}
	// len(body) == 0 is a special case; it means no packet rather than a
	// zero-length packet.
	if len(body) > 0 {
		h.conn.QueueIncoming(body, remoteAddr)
	}

	// Send a packet in the response if available.
	rw.Header().Set("Content-Type", "application/octet-stream")
	select {
	case p := <-h.conn.OutgoingQueue(remoteAddr):
		// Best effort: if the write fails, the packet is lost like any
		// packet dropped from a full queue.
		rw.Write(p)
	default:
	}
}

// handleStream dials forwardAddr over TCP and copies data in both directions
// between stream and that connection until both directions have finished. It
// returns an error only if the dial fails; errors during copying are logged.
//
// When upstream has no more data, the stream is closed. That tells the client
// the upstream side is finished. Without it, the stream would stay open until
// the client closed it. A client that waits for the server to close first
// would then deadlock with the stream→upstream goroutine, which keeps reading
// from the stream.
func handleStream(stream *smux.Stream, forwardAddr string) error {
	c, err := net.Dial("tcp", forwardAddr)
	if err != nil {
		return err
	}
	conn := c.(*net.TCPConn)
	defer conn.Close()

	var wg sync.WaitGroup
	wg.Add(2)
	// Stream → upstream.
	go func() {
		defer wg.Done()
		_, err := io.Copy(conn, stream)
		if err != nil {
			log.Printf("stream %v (session %v) recv err: %v", stream.ID(), stream.RemoteAddr(), err)
		}
		log.Printf("stream %v (session %v) recv done", stream.ID(), stream.RemoteAddr())
		// Pass the end of client data on to upstream as a TCP half-close.
		err = conn.CloseWrite()
		if err != nil {
			log.Printf("stream %v (session %v) CloseWrite err: %v", stream.ID(), stream.RemoteAddr(), err)
		}
	}()
	// Upstream → stream.
	go func() {
		defer wg.Done()
		_, err := io.Copy(stream, conn)
		if err != nil {
			log.Printf("stream %v (session %v) send err: %v", stream.ID(), stream.RemoteAddr(), err)
		}
		log.Printf("stream %v (session %v) send done", stream.ID(), stream.RemoteAddr())
		err = conn.CloseRead()
		if err != nil {
			log.Printf("stream %v (session %v) CloseRead err: %v", stream.ID(), stream.RemoteAddr(), err)
		}
		// Signal end of upstream data to the client (see the function
		// comment). The caller may close the stream again afterwards; smux
		// is expected to tolerate repeated Close calls — verify for the smux
		// version in use.
		err = stream.Close()
		if err != nil {
			log.Printf("stream %v (session %v) Close err: %v", stream.ID(), stream.RemoteAddr(), err)
		}
	}()
	wg.Wait()
	return nil
}

// acceptStreams serves every stream opened on sess. Each stream is forwarded
// to forwardAddr by handleStream on its own goroutine and is closed once
// handleStream returns. Temporary *net.OpError accept failures are logged and
// retried. Any other accept error ends the loop and is returned.
func acceptStreams(sess *smux.Session, forwardAddr string) error {
	for {
		stream, err := sess.AcceptStream()
		if err != nil {
			opErr, isOpErr := err.(*net.OpError)
			if !isOpErr || !opErr.Temporary() {
				return err
			}
			log.Printf("temporary error in sess.AcceptStream: %v", opErr)
			continue
		}

		go func(s *smux.Stream) {
			defer s.Close()
			log.Printf("begin stream %v (session %v)", s.ID(), s.RemoteAddr())
			if err := handleStream(s, forwardAddr); err != nil {
				log.Printf("error in handleStream: %v", err)
			}
			log.Printf("end stream %v (session %v)", s.ID(), s.RemoteAddr())
		}(stream)
	}
}

// acceptSessions accepts KCP connections from ln. On each connection it runs
// an smux server session, whose streams acceptStreams forwards to
// forwardAddr. Each connection is handled on its own goroutine, and the
// session and connection are closed when acceptStreams returns. Temporary
// *net.OpError accept failures are logged and retried. Any other accept
// error is returned.
func acceptSessions(ln *kcp.Listener, forwardAddr string) error {
	for {
		conn, err := ln.Accept()
		if err != nil {
			opErr, isOpErr := err.(*net.OpError)
			if !isOpErr || !opErr.Temporary() {
				return err
			}
			log.Printf("temporary error in ln.Accept: %v", opErr)
			continue
		}

		go func() {
			defer conn.Close()

			sess, err := smux.Server(conn, smux.DefaultConfig())
			if err != nil {
				log.Printf("error in smux.Server: %v", err)
				return
			}
			defer sess.Close()

			log.Printf("begin session %v", sess.RemoteAddr())
			if err := acceptStreams(sess, forwardAddr); err != nil {
				log.Printf("error in acceptStreams: %v", err)
			}
			log.Printf("end session %v", sess.RemoteAddr())
		}()
	}
}

// run serves the HTTP transport on listenAddr and runs a KCP listener over the
// packets that transport carries. Every smux stream it accepts is forwarded
// to the TCP address forwardAddr. run returns only on error: when listenAddr
// cannot be bound, when the KCP listener cannot be created, or when
// acceptSessions fails.
func run(listenAddr, forwardAddr string) error {
	conn := newQueuePacketConn(nil)
	defer conn.Close()

	// Bind before serving, so that an unusable listenAddr is returned to the
	// caller. Otherwise it would only be logged while the KCP listener waits
	// forever for packets that can never arrive.
	bindAddr := listenAddr
	if bindAddr == "" {
		bindAddr = ":http" // same default as http.Server.ListenAndServe
	}
	httpLn, err := net.Listen("tcp", bindAddr)
	if err != nil {
		return err
	}

	// A private mux avoids registering on the global http.DefaultServeMux.
	mux := http.NewServeMux()
	mux.Handle("/", &httpHandler{conn})
	server := &http.Server{
		Addr:    listenAddr,
		Handler: mux,
		// Each request and response carries at most one small packet, so
		// these timeouts only cut off clients that stall mid-exchange.
		ReadTimeout:  15 * time.Second,
		WriteTimeout: 15 * time.Second,
	}
	defer server.Close()
	go func() {
		err := server.Serve(httpLn)
		if err != nil && err != http.ErrServerClosed {
			log.Printf("error in Serve: %v", err)
		}
	}()

	ln, err := kcp.ServeConn(nil, 0, 0, conn)
	if err != nil {
		return err
	}
	defer ln.Close()

	return acceptSessions(ln, forwardAddr)
}

// main expects exactly two positional arguments: LISTENADDR, the HTTP listen
// address, and FORWARDADDR, the TCP address that streams are forwarded to.
// It runs the server with them and exits with status 1 on a usage error or
// if run fails. Log timestamps are in UTC.
func main() {
	log.SetFlags(log.LstdFlags | log.LUTC)

	flag.Parse()
	args := flag.Args()
	if len(args) != 2 {
		fmt.Fprintf(os.Stderr, "usage: %s LISTENADDR FORWARDADDR\n", os.Args[0])
		os.Exit(1)
	}

	if err := run(args[0], args[1]); err != nil {
		log.Println(err)
		os.Exit(1)
	}
}
