// Forwards local TCP connections to a remote server over an HTTP-encoded SCTP
// association. Each local TCP connection becomes a stream within the
// association.

package main

import (
	"bytes"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/pion/logging"
	"github.com/pion/sctp"
)

const (
	// queueCapacity is the buffer size of httpConn.recvQueue: how many
	// received packets (or errors) may wait for Read. When it is full, Write
	// drops further non-error packets.
	queueCapacity = 10
)

// clientID is a random byte string that tells clients apart, much as an
// (IP address, port) source tuple would. It is the value returned by
// httpConn.LocalAddr, and it accompanies every packet sent through the HTTP
// tunnel. When the server receives a packet wrapped in an HTTP request, it
// can look at the clientID to learn which of possibly many clients sent it --
// just as it could read the source address of a received IP-encapsulated UDP
// packet. The clientID's most important job is to let the server route its
// downstream data: on receiving a request carrying clientID "abcd", the server
// knows it may put downstream data meant for client "abcd" into the HTTP
// response.
type clientID []byte

// clientIDLength is the number of random bytes in a clientID.
const clientIDLength = 8

// newClientID returns a fresh clientID of clientIDLength bytes read from
// crypto/rand. It panics if the randomness source fails, because no usable
// identifier can be produced in that case.
func newClientID() clientID {
	id := clientID(make([]byte, clientIDLength))
	if _, err := rand.Read(id); err != nil {
		panic(err)
	}
	return id
}

// The two methods below let a clientID serve as a net.Addr.

// Network reports the pseudo-network name "clientid".
func (addr clientID) Network() string { return "clientid" }

// String renders the identifier as lowercase hexadecimal.
func (addr clientID) String() string { return hex.EncodeToString(addr) }

// packetErr bundles a received packet with the error that came with it,
// mirroring the ([]byte, error) pair a read produces. Either part may be
// empty: a queued error is carried with a nil packet.
type packetErr struct {
	P   []byte // packet payload, possibly nil
	Err error  // error to report to the reader, or nil
}

// httpConn behaves like a connected socket talking to one web server named
// by an HTTP URL. Each Write is carried as one HTTP POST request, and packets
// that come back in response bodies are handed out by Read.
type httpConn struct {
	// serverURL is the single web server this conn exchanges packets with.
	serverURL string
	// localAddr is this client's identifier, sent with every request.
	localAddr clientID
	// closed is closed by Close to signal that the conn is shut down.
	closed chan struct{}
	// recvQueue carries non-empty HTTP response bodies, and read errors,
	// from Write and pollLoop to Read.
	recvQueue chan packetErr
	// pollQueue steers pollLoop: each value received on it makes pollLoop
	// reset its poll timer.
	pollQueue chan struct{}
}

// newHTTPConn returns an httpConn bound to the web server at serverURL,
// identified by a freshly generated clientID. It also starts pollLoop in a
// background goroutine; pollLoop watches the conn's closed channel.
func newHTTPConn(serverURL string) *httpConn {
	c := new(httpConn)
	c.serverURL = serverURL
	c.localAddr = newClientID()
	c.closed = make(chan struct{})
	c.recvQueue = make(chan packetErr, queueCapacity)
	c.pollQueue = make(chan struct{})
	go c.pollLoop()
	return c
}

// pollLoop sends an empty polling request to the server whenever no organic
// request has been made in a while, giving the server a chance to return
// downstream data. Sending a value on conn.pollQueue postpones the next poll
// to 50ms from then. Each poll that fires without such a reset doubles the
// interval before the next one, up to a maximum of 10s (which is also the
// initial interval).
//
// The need to poll is incidentally why httpConn is bound to a single
// remote server, and Write does not accept varying addresses to indicate
// different servers.
//
// pollLoop returns when conn is closed. An error from a polling request is
// delivered to Read through recvQueue. If the conn is closed while waiting
// for queue space, the error is dropped and pollLoop returns instead of
// blocking forever.
func (conn *httpConn) pollLoop() {
	const maxDelay = 10 * time.Second
	delay := maxDelay
	delayTimer := time.NewTimer(delay)
	defer delayTimer.Stop()
	for {
		select {
		case <-conn.closed:
			return
		case <-conn.pollQueue:
			// A real request is about to be sent, so cancel the pending
			// poll (draining the channel if it already fired) and restart
			// the backoff from a short delay.
			if !delayTimer.Stop() {
				<-delayTimer.C
			}
			delay = 50 * time.Millisecond
		case <-delayTimer.C:
			delay *= 2
			if delay > maxDelay {
				delay = maxDelay
			}
			// Send an empty polling request.
			if _, err := conn.Write(nil); err != nil {
				// Don't block forever on a full queue once nobody is
				// going to read it.
				select {
				case conn.recvQueue <- packetErr{nil, err}:
				case <-conn.closed:
					return
				}
			}
		}
		delayTimer.Reset(delay)
	}
}

var (
	// errClosedPacketConn is the underlying error reported by Read, Write,
	// and Close once the httpConn has been closed.
	errClosedPacketConn = errors.New("closed conn")
)

// Read waits for the next packet (or queued error) from the remote server and
// copies the packet into p, truncating it if p is too small. If the conn is
// closed, Read returns a *net.OpError wrapping errClosedPacketConn; a Read
// that is already waiting when Close is called is woken up by it.
func (conn *httpConn) Read(p []byte) (int, error) {
	var pe packetErr
	select {
	case <-conn.closed:
		return 0, &net.OpError{
			Op:     "read",
			Net:    conn.RemoteAddr().Network(),
			Source: conn.LocalAddr(),
			Addr:   conn.RemoteAddr(),
			Err:    errClosedPacketConn,
		}
	case pe = <-conn.recvQueue:
	}
	n := copy(p, pe.P)
	return n, pe.Err
}

// Write sends p to the remote server as the body of a single HTTP POST,
// tagged with this conn's clientID in the Client-Id header. An empty p is a
// polling request. It returns len(p), nil once a response has been received.
// Errors from building or round-tripping the request are returned directly.
// On a closed conn it returns a *net.OpError wrapping errClosedPacketConn.
//
// A response body may carry one downstream packet; at most 2000 bytes of it
// are read.
//   - For a 2xx response, a non-empty body is queued for Read. It is dropped
//     if the queue is full.
//   - For any other status, the body is not a packet (presumably an error
//     page), so it is discarded, just like a lost packet.
//   - An error while reading the body is queued for Read rather than
//     returned. It is dropped only if the conn is closed while waiting for
//     queue space, so that Write never blocks forever after Close.
func (conn *httpConn) Write(p []byte) (int, error) {
	select {
	case <-conn.closed:
		return 0, &net.OpError{Op: "write", Net: conn.RemoteAddr().Network(), Source: conn.LocalAddr(), Addr: conn.RemoteAddr(), Err: errClosedPacketConn}
	default:
	}

	// Inhibit polling because we're about to send a real request. The send
	// is non-blocking because pollLoop may be busy sending a poll itself.
	if len(p) > 0 {
		select {
		case conn.pollQueue <- struct{}{}:
		default:
		}
	}

	// One request per packet.
	req, err := http.NewRequest("POST", conn.serverURL, bytes.NewReader(p))
	if err != nil {
		return 0, err
	}
	req.Header.Set("Content-Type", "application/octet-stream")
	// The clientID could alternatively be prepended to the request body.
	req.Header.Set("Client-Id", conn.LocalAddr().String())
	resp, err := http.DefaultTransport.RoundTrip(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	// Only a successful response can carry a server packet; anything else
	// must not be handed to the packet reader.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return len(p), nil
	}

	// The response body may contain one server packet.
	buf, err := ioutil.ReadAll(io.LimitReader(resp.Body, 2000))
	switch {
	case err != nil:
		// Make sure errors reach the reader, but give up if the conn is
		// closed, since nobody will drain the queue after that.
		select {
		case conn.recvQueue <- packetErr{buf, err}:
		case <-conn.closed:
		}
	case len(buf) > 0:
		// OK to drop non-error packets.
		select {
		case conn.recvQueue <- packetErr{buf, nil}:
		default:
		}
	}

	return len(p), nil
}

// Close shuts the conn down by closing conn.closed. That stops pollLoop,
// wakes any pending Read, and makes later Read and Write calls fail. Calling
// Close a second time returns a *net.OpError wrapping errClosedPacketConn.
//
// NOTE(review): the "already closed?" check and the close() are not atomic,
// so two truly concurrent Close calls could both reach close() and panic.
func (conn *httpConn) Close() error {
	select {
	case <-conn.closed:
		return &net.OpError{Op: "close", Net: conn.LocalAddr().Network(), Addr: conn.LocalAddr(), Err: errClosedPacketConn}
	default:
	}
	// Closing the channel is what makes pollLoop return.
	close(conn.closed)
	return nil
}

// LocalAddr returns this conn's clientID, which the server uses to tell this
// client apart from others.
func (conn *httpConn) LocalAddr() net.Addr {
	return conn.localAddr
}

// SetDeadline is unsupported and always returns an error.
func (conn *httpConn) SetDeadline(t time.Time) error {
	return errNotImplemented()
}

// SetReadDeadline is unsupported and always returns an error.
func (conn *httpConn) SetReadDeadline(t time.Time) error {
	return errNotImplemented()
}

// SetWriteDeadline is unsupported and always returns an error.
func (conn *httpConn) SetWriteDeadline(t time.Time) error {
	return errNotImplemented()
}

// errNotImplemented builds the error reported by the unsupported deadline
// setters. Each call returns a new error value.
func errNotImplemented() error {
	return errors.New("not implemented")
}

// urlAddr adapts a web server URL to the net.Addr interface, so that a URL
// can stand in for a remote address.
type urlAddr string

// Network reports the pseudo-network name "url".
func (addr urlAddr) Network() string { return "url" }

// String returns the URL unchanged.
func (addr urlAddr) String() string { return string(addr) }

// RemoteAddr reports the server URL this conn is bound to, as a net.Addr.
func (conn *httpConn) RemoteAddr() net.Addr {
	addr := urlAddr(conn.serverURL)
	return addr
}

// handleLocalConn opens stream streamID on assoc and relays data between that
// stream and the local TCP connection conn, in both directions concurrently.
// It returns after both directions have finished. Errors from the copies are
// logged rather than returned; the only error reported is a failure to open
// the stream.
func handleLocalConn(conn *net.TCPConn, assoc *sctp.Association, streamID uint16) error {
	// Payload protocol identifier 0: per
	// https://tools.ietf.org/html/rfc4960#section-14.4, "value 0 ... is
	// reserved by SCTP to indicate an unspecified payload protocol
	// identifier in a DATA chunk."
	stream, err := assoc.OpenStream(streamID, 0)
	if err != nil {
		return err
	}
	log.Printf("stream ID %v", stream.StreamIdentifier())

	var wg sync.WaitGroup
	wg.Add(2)

	// Remote to local. When the stream ends, half-close the TCP connection
	// so the local peer sees end of data.
	go func() {
		defer wg.Done()
		if _, copyErr := io.Copy(conn, stream); copyErr != nil {
			log.Printf("recv error: %v", copyErr)
		}
		log.Printf("recv done")
		if shutErr := conn.CloseWrite(); shutErr != nil {
			log.Printf("conn shutdown error: %v", shutErr)
		}
	}()

	// Local to remote. When the TCP connection ends, close the stream.
	go func() {
		defer wg.Done()
		if _, copyErr := io.Copy(stream, conn); copyErr != nil {
			log.Printf("send error: %v", copyErr)
		}
		log.Printf("send done")
		if closeErr := stream.Close(); closeErr != nil {
			log.Printf("stream close error: %v", closeErr)
		}
	}()

	wg.Wait()
	return nil
}

// acceptLocalConns accepts TCP connections on ln until accepting fails
// permanently. Each connection is handed to handleLocalConn in its own
// goroutine, on a new stream of assoc. Stream IDs are assigned sequentially
// from 0 and wrap around after 65535 because they are uint16.
//
// A temporary accept error (e.g. running out of file descriptors) is logged
// and retried after a pause. The pause starts at 5ms and doubles up to 1s,
// and resets after a successful accept, so a persistent condition does not
// turn the loop into a busy spin. Any other accept error is returned.
func acceptLocalConns(ln *net.TCPListener, assoc *sctp.Association) error {
	const (
		minRetryDelay = 5 * time.Millisecond
		maxRetryDelay = time.Second
	)
	var (
		streamID   uint16
		retryDelay time.Duration
	)
	for {
		conn, err := ln.AcceptTCP()
		if err != nil {
			// Temporary is deprecated in general, but for Accept it still
			// marks conditions that may clear on their own; net/http's
			// Server.Serve uses it the same way.
			if opErr, ok := err.(*net.OpError); ok && opErr.Temporary() {
				if retryDelay == 0 {
					retryDelay = minRetryDelay
				} else {
					retryDelay *= 2
				}
				if retryDelay > maxRetryDelay {
					retryDelay = maxRetryDelay
				}
				log.Printf("temporary error in ln.Accept: %v", opErr)
				time.Sleep(retryDelay)
				continue
			}
			return err
		}
		retryDelay = 0

		go func(conn *net.TCPConn, streamID uint16) {
			defer conn.Close()
			if err := handleLocalConn(conn, assoc, streamID); err != nil {
				log.Printf("error in handleLocalConn: %v", err)
			}
		}(conn, streamID)
		streamID++
	}
}

// run listens for local TCP connections on listenAddr and creates an SCTP
// client association that runs over an httpConn to serverURL. It then
// forwards every accepted connection over its own stream. It returns when
// setup or accepting fails. The deferred calls then close the association,
// the httpConn, and the listener, in that order.
func run(listenAddr, serverURL string) error {
	listener, err := net.Listen("tcp", listenAddr)
	if err != nil {
		return err
	}
	defer listener.Close()

	tunnel := newHTTPConn(serverURL)
	defer tunnel.Close()

	assoc, err := sctp.Client(sctp.Config{
		NetConn:       tunnel,
		LoggerFactory: logging.NewDefaultLoggerFactory(),
	})
	if err != nil {
		return err
	}
	defer assoc.Close()

	// net.Listen with network "tcp" yields a *net.TCPListener.
	return acceptLocalConns(listener.(*net.TCPListener), assoc)
}

// main parses the LISTENADDR and URL arguments and runs the forwarder. It
// exits with status 1 on a usage error or when run fails.
func main() {
	log.SetFlags(log.LstdFlags | log.LUTC)

	flag.Parse()
	args := flag.Args()
	if len(args) != 2 {
		fmt.Fprintf(os.Stderr, "usage: %s LISTENADDR URL\n", os.Args[0])
		os.Exit(1)
	}

	if err := run(args[0], args[1]); err != nil {
		log.Println(err)
		os.Exit(1)
	}
}
