// Accepts SCTP associations whose packets are carried in HTTP request and
// response bodies, and forwards each SCTP stream to a TCP listener.

package main

import (
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/pion/logging"
	"github.com/pion/sctp"
)

const (
	// queueCapacity is how many packets a queueConn buffers in each direction;
	// packets arriving while a queue is full are dropped.
	queueCapacity = 10
	// connMapTimeout is how long an idle connection is kept before it is
	// closed and discarded.
	connMapTimeout = time.Minute
)

// clientID acts as a return address for packets received over HTTP. It is the
// opaque string a client sends in the Client-Id request header (the original
// design describes it as random). Because the HTTP request that delivered a
// packet is tagged with its clientID, the response to that request is a
// channel for sending downstream packets back to the same client.
type clientID string

// Network implements net.Addr. Every clientID belongs to the pseudo-network
// "clientid".
func (id clientID) Network() string { return "clientid" }

// String implements net.Addr by returning the identifier unchanged.
func (id clientID) String() string { return string(id) }

// queueConn implements the net.Conn interface by storing queues of packets.
// QueueIncoming queues a packet for a future call to Read. Write queues a
// packet that can later be read from OutgoingQueue(). Each Read or Write call
// transfers exactly one packet, so packet boundaries are preserved.
//
// The net.Conn contract requires methods to be safe for concurrent use.
// Concurrent Close calls therefore must not double-close the closed channel.
type queueConn struct {
	closed     chan struct{} // closed by Close to unblock Read
	closeOnce  sync.Once     // ensures closed is closed at most once
	recvQueue  chan []byte   // filled by QueueIncoming, drained by Read
	sendQueue  chan []byte   // filled by Write, drained via OutgoingQueue
	remoteAddr net.Addr
}

// Compile-time check that *queueConn satisfies net.Conn.
var _ net.Conn = (*queueConn)(nil)

// newQueueConn returns an open queueConn that reports remoteAddr as its
// RemoteAddr. Each direction buffers up to queueCapacity packets.
func newQueueConn(remoteAddr net.Addr) *queueConn {
	return &queueConn{
		closed:     make(chan struct{}),
		recvQueue:  make(chan []byte, queueCapacity),
		sendQueue:  make(chan []byte, queueCapacity),
		remoteAddr: remoteAddr,
	}
}

// QueueIncoming queues p for a future call to Read and takes ownership of p.
// The caller must not modify p afterwards. It never blocks: if the receive
// queue is full, p is dropped.
func (conn *queueConn) QueueIncoming(p []byte) {
	select {
	case conn.recvQueue <- p:
		// Packet queued.
	default:
		// OK to drop packet when queue is full.
	}
}

// OutgoingQueue returns the channel of packets previously passed to Write,
// oldest first.
func (conn *queueConn) OutgoingQueue() chan []byte {
	return conn.sendQueue
}

var errClosedPacketConn = errors.New("closed conn")

// closedError builds the error that a Read or Write (named by op) reports on a
// closed conn.
func (conn *queueConn) closedError(op string) error {
	return &net.OpError{Op: op, Net: conn.RemoteAddr().Network(), Source: conn.LocalAddr(), Addr: conn.RemoteAddr(), Err: errClosedPacketConn}
}

// Read blocks until a packet queued by QueueIncoming is available. It then
// copies that packet into p and returns the number of bytes copied. If p is
// shorter than the packet, the remainder of the packet is discarded. Once the
// conn is closed, Read returns an error, even if packets are still queued.
func (conn *queueConn) Read(p []byte) (int, error) {
	// When several select cases are ready, Go picks one at random. Without
	// this check, a closed conn with queued packets could keep returning data.
	select {
	case <-conn.closed:
		return 0, conn.closedError("read")
	default:
	}
	select {
	case packet := <-conn.recvQueue:
		return copy(p, packet), nil
	case <-conn.closed:
		return 0, conn.closedError("read")
	}
}

// Write queues a copy of p as one outgoing packet, to be retrieved from
// OutgoingQueue. It never blocks. If the send queue is full, the packet is
// dropped and Write still reports success, as a lossy link would. Once the
// conn is closed, Write returns an error.
func (conn *queueConn) Write(p []byte) (int, error) {
	// Fail deterministically once closed; see the comment in Read.
	select {
	case <-conn.closed:
		return 0, conn.closedError("write")
	default:
	}
	// Copy the slice, because the caller may reuse it.
	c := make([]byte, len(p))
	copy(c, p)
	select {
	case conn.sendQueue <- c:
		// Packet queued.
	case <-conn.closed:
		// Closed concurrently after the check above.
		return 0, conn.closedError("write")
	default:
		// OK to drop packet when queue is full.
	}
	return len(c), nil
}

// Close closes the queueConn. It unblocks any pending Read, and Read and Write
// calls made afterwards return errors. Only the first call succeeds; later or
// concurrent calls return an error instead of panicking.
func (conn *queueConn) Close() error {
	first := false
	conn.closeOnce.Do(func() {
		close(conn.closed)
		first = true
	})
	if !first {
		return &net.OpError{Op: "close", Net: conn.LocalAddr().Network(), Addr: conn.LocalAddr(), Err: errClosedPacketConn}
	}
	return nil
}

// An empty struct that satisfies the net.Addr interface. It is the LocalAddr
// of every queueConn.
type dummyAddr struct{}

func (addr dummyAddr) Network() string {
	return "local"
}

func (addr dummyAddr) String() string {
	return "local"
}

// LocalAddr always returns a dummyAddr.
func (conn *queueConn) LocalAddr() net.Addr {
	return dummyAddr{}
}

// RemoteAddr returns the address given to newQueueConn.
func (conn *queueConn) RemoteAddr() net.Addr {
	return conn.remoteAddr
}

// SetDeadline is not supported and always returns an error.
func (conn *queueConn) SetDeadline(t time.Time) error {
	return errors.New("not implemented")
}

// SetReadDeadline is not supported and always returns an error.
func (conn *queueConn) SetReadDeadline(t time.Time) error {
	return errors.New("not implemented")
}

// SetWriteDeadline is not supported and always returns an error.
func (conn *queueConn) SetWriteDeadline(t time.Time) error {
	return errors.New("not implemented")
}

// handleStream connects one SCTP stream to a new TCP connection to forwardAddr
// and copies bytes in both directions until each direction has finished.
// When a direction finishes, the matching half of the TCP connection is shut
// down. The only error it returns is a failure to dial. Copy and shutdown
// errors are logged instead.
func handleStream(stream *sctp.Stream, forwardAddr string) error {
	dialed, err := net.Dial("tcp", forwardAddr)
	if err != nil {
		return err
	}
	tcp := dialed.(*net.TCPConn)
	defer tcp.Close()

	var wg sync.WaitGroup

	// pump copies src into dst, then calls halfClose. Each message argument is
	// a log.Printf format that takes the stream identifier and, for the error
	// messages, the error.
	pump := func(dst io.Writer, src io.Reader, copyErrMsg, doneMsg string, halfClose func() error, closeErrMsg string) {
		defer wg.Done()
		if _, err := io.Copy(dst, src); err != nil {
			log.Printf(copyErrMsg, stream.StreamIdentifier(), err)
		}
		log.Printf(doneMsg, stream.StreamIdentifier())
		if err := halfClose(); err != nil {
			log.Printf(closeErrMsg, stream.StreamIdentifier(), err)
		}
	}

	wg.Add(2)
	// SCTP -> TCP: once the stream is drained, stop writing to TCP.
	go pump(tcp, stream, "stream %v recv err: %v", "stream %v recv done", tcp.CloseWrite, "stream %v CloseWrite err: %v")
	// TCP -> SCTP: once TCP is drained, stop reading from TCP.
	go pump(stream, tcp, "stream %v send err: %v", "stream %v send done", tcp.CloseRead, "stream %v CloseRead err: %v")
	wg.Wait()
	return nil
}

// acceptStreams accepts streams on assoc and hands each one to handleStream,
// in its own goroutine, for forwarding to forwardAddr. Temporary
// *net.OpError failures are logged and retried. It returns nil when
// AcceptStream reports io.EOF, which is treated as a clean end of the
// association. Any other error is returned.
func acceptStreams(assoc *sctp.Association, forwardAddr string) error {
	for {
		stream, err := assoc.AcceptStream()
		if err != nil {
			// errors.As/errors.Is also recognize wrapped errors, and this
			// avoids shadowing err with the asserted *net.OpError.
			// NOTE: net.Error.Temporary is deprecated since Go 1.18; it is
			// kept here to preserve the existing retry behavior.
			var opErr *net.OpError
			if errors.As(err, &opErr) && opErr.Temporary() {
				log.Printf("temporary error in assoc.AcceptStream: %v", err)
				continue
			}
			if errors.Is(err, io.EOF) {
				return nil
			}
			return err
		}

		go func() {
			defer stream.Close()
			log.Printf("begin stream %v", stream.StreamIdentifier())
			err := handleStream(stream, forwardAddr)
			if err != nil {
				log.Printf("error in handleStream: %v", err)
			}
			log.Printf("end stream %v", stream.StreamIdentifier())
		}()
	}
}

// acceptConns accepts connections from ln until Accept fails with an error
// that is not a temporary *net.OpError, and returns that error. Temporary
// errors are logged and retried. Each connection is handled in its own
// goroutine, which runs an SCTP server association over it and passes the
// association to acceptStreams to forward its streams to forwardAddr.
func acceptConns(ln *connMap, forwardAddr string) error {
	for {
		conn, err := ln.Accept()
		if err != nil {
			// errors.As also recognizes a wrapped *net.OpError and avoids
			// shadowing err. NOTE: net.Error.Temporary is deprecated since
			// Go 1.18; it is kept here to preserve the existing retry behavior.
			var opErr *net.OpError
			if errors.As(err, &opErr) && opErr.Temporary() {
				log.Printf("temporary error in ln.Accept: %v", err)
				continue
			}
			return err
		}

		go func() {
			defer conn.Close()

			assoc, err := sctp.Server(sctp.Config{
				NetConn:       conn,
				LoggerFactory: logging.NewDefaultLoggerFactory(),
			})
			if err != nil {
				log.Printf("error in sctp.Server: %v", err)
				return
			}
			defer assoc.Close()

			log.Printf("begin association %v", conn.RemoteAddr())
			err = acceptStreams(assoc, forwardAddr)
			if err != nil {
				log.Printf("error in acceptStreams: %v", err)
			}
			log.Printf("end association %v", conn.RemoteAddr())
		}()
	}
}

// An HTTP request handler that reads incoming packets from request bodies and
// maps them to a queueConn, and writes outgoing packets that pertain to the
// queueConn in response bodies.
type httpHandler struct{ conns *connMap }

// ServeHTTP exchanges at most one packet in each direction with the queueConn
// named by the request's Client-Id header. A non-empty request body of up to
// maxPacketSize bytes is queued as one incoming packet; an empty body means
// "no packet". If an outgoing packet is waiting, it becomes the response body.
// The request is rejected with an error status, and nothing is queued, if:
//   - the Client-Id header is missing (400);
//   - the body cannot be read (400);
//   - the body is too large (413).
func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	defer req.Body.Close()

	// Largest request body, in bytes, accepted as a single packet.
	const maxPacketSize = 2000

	client := req.Header.Get("Client-Id")
	if client == "" {
		rw.WriteHeader(http.StatusBadRequest)
		return
	}
	var remoteAddr net.Addr = clientID(client)
	conn := h.conns.Get(remoteAddr, time.Now())

	// Read one byte past the limit, so an oversized body is detected rather
	// than silently truncated into a corrupt packet.
	buf, err := ioutil.ReadAll(io.LimitReader(req.Body, maxPacketSize+1))
	if err != nil {
		// ReadAll returns what it read before the error. Don't queue it,
		// because it would be a truncated packet. (A successful ReadAll
		// returns a nil error, never io.EOF.)
		rw.WriteHeader(http.StatusBadRequest)
		return
	}
	if len(buf) > maxPacketSize {
		rw.WriteHeader(http.StatusRequestEntityTooLarge)
		return
	}
	// len(buf) == 0 is a special case; it means no packet rather
	// than a zero-length packet.
	if len(buf) > 0 {
		conn.QueueIncoming(buf)
	}

	// Send a packet in the response if available.
	rw.Header().Set("Content-Type", "application/octet-stream")
	select {
	case packet := <-conn.OutgoingQueue():
		// Best effort: if the write fails the packet is lost, as when the
		// conn's queues drop packets because they are full.
		_, _ = rw.Write(packet)
	default:
	}
}

// run serves the HTTP packet transport on listenAddr and forwards every SCTP
// stream carried over it to the TCP address forwardAddr. It returns an error
// right away if listenAddr cannot be bound. Otherwise it blocks in
// acceptConns and returns the error that ends that loop.
func run(listenAddr, forwardAddr string) error {
	// Bind before anything else, so a bad or busy address is reported to the
	// caller. Previously ListenAndServe failed inside a goroutine, the error
	// was only logged, and run kept going with no HTTP server.
	addr := listenAddr
	if addr == "" {
		addr = ":http" // same default as http.Server.ListenAndServe
	}
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}

	conns := newConnMap(connMapTimeout)

	// A private mux avoids registering on the global http.DefaultServeMux,
	// where a second registration of "/" would panic.
	mux := http.NewServeMux()
	mux.Handle("/", &httpHandler{conns})
	server := &http.Server{
		Addr:    listenAddr,
		Handler: mux,
	}
	go func() {
		err := server.Serve(ln)
		if err != nil {
			log.Printf("error in Serve: %v", err)
		}
	}()

	return acceptConns(conns, forwardAddr)
}

// main parses the two required positional arguments, LISTENADDR and
// FORWARDADDR, and calls run. It exits with status 1 if the arguments are
// wrong or run returns an error.
func main() {
	log.SetFlags(log.LstdFlags | log.LUTC)

	flag.Parse()
	args := flag.Args()
	if len(args) != 2 {
		fmt.Fprintf(os.Stderr, "usage: %s LISTENADDR FORWARDADDR\n", os.Args[0])
		os.Exit(1)
	}

	if err := run(args[0], args[1]); err != nil {
		log.Println(err)
		os.Exit(1)
	}
}
