// This is the native-application half of the fetch-rpc WebExtension, which
// exposes the browser's fetch API
// (https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API) on a socket. A
// wholly in-browser WebExtension cannot open sockets by itself, so this program
// opens a socket to receive fetch parameters and passes them into the browser
// over the native-messaging stdin/stdout channel.
//
// This program does minimal syntax checking. Apart from ensuring that the
// messages are well-formed JSON with the correct top-level properties, and
// inspecting the ID, it doesn't care about the contents of messages. It's the
// responsibility of the browser half of the WebExtension to check everything
// else.

package main

import (
	"crypto/rand"
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"net"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"
)

const (
	// listenAddr is the loopback address on which the client program
	// connects to hand us fetch specifications.
	listenAddr = "127.0.0.1:9901"

	// waitForBrowserTimeout bounds how long fetch waits for the browser to
	// answer a request it was sent. It is deliberately generous and is not
	// a substitute for the browser's own request timeout; it exists only so
	// that memory can be reclaimed if the browser somehow loses a request.
	waitForBrowserTimeout = 300 * time.Second
)

// fetchResult carries the outcome of one fetch call made by the browser.
// Response is set when the fetch promise was fulfilled and Error when it was
// rejected. Both hold arbitrary decoded JSON that this program passes along
// without interpreting.
type fetchResult struct {
	Response interface{} `json:"response,omitempty"` // fulfilled value
	Error    interface{} `json:"error,omitempty"`    // rejection reason
}

// Many socket connections may be in flight at once, but they all share the
// single stdio stream to the browser, so their requests and results have to be
// multiplexed. fetch tags each outgoing request with a random ID and registers
// a result channel under that ID in resultMap. When inFromBrowserLoop reads a
// reply, it uses the reply's ID to look up that channel and delivers the
// result on it. Every access to resultMap is made while holding resultMapLock.
var (
	resultMap     = make(map[string]chan<- *fetchResult)
	resultMapLock sync.Mutex
)

// webExtensionRequest is the envelope written to the browser on stdout: the
// client's request, passed through unmodified, tagged with the ID that the
// browser's reply must carry for inFromBrowserLoop to route it.
type webExtensionRequest struct {
	ID      string      `json:"id"`      // random routing tag
	Request interface{} `json:"request"` // client request, uninterpreted
}

// webExtensionResult is the envelope read from the browser on stdin. Its ID
// echoes the ID of an earlier webExtensionRequest. Because fetchResult is
// embedded, its "response" and "error" properties sit at the top level of the
// JSON object, next to "id".
type webExtensionResult struct {
	ID string `json:"id"`

	fetchResult
}

// recvRequest reads a request specification sent by the client program over
// its socket connection. Only the first JSON value is decoded; anything that
// follows it on the connection is ignored. The value comes back as generic
// decoded JSON (maps, slices, strings, float64s, ...) because this program
// does not inspect it. On error the returned request is nil.
func recvRequest(r io.Reader) (interface{}, error) {
	var request interface{}
	if err := json.NewDecoder(r).Decode(&request); err != nil {
		return nil, err
	}
	return request, nil
}

// sendResult writes result back to the client program over its socket
// connection, encoded as a single JSON value followed by a newline (the
// framing json.Encoder produces).
func sendResult(w io.Writer, result interface{}) error {
	return json.NewEncoder(w).Encode(result)
}

// recvWebExtensionMessage reads one native-messaging frame from r and decodes
// its JSON payload into message. As with json.Unmarshal, message must be a
// non-nil pointer. The frame is a uint32 payload length in NativeEndian byte
// order followed by that many bytes of JSON.
// https://developer.mozilla.org/en-US/docs/Mozilla/Add-ons/WebExtensions/Native_messaging#App_side
//
// Errors are returned unwrapped, so a clean end of stream before a length
// prefix is reported as io.EOF. After any error the read position in r is
// unreliable and callers should stop reading.
func recvWebExtensionMessage(r io.Reader, message interface{}) error {
	var length uint32
	if err := binary.Read(r, NativeEndian, &length); err != nil {
		return err
	}
	// NOTE(review): length comes straight from the browser and may be as
	// large as 4 GiB; the whole payload is buffered before decoding.
	encoded := make([]byte, length)
	if _, err := io.ReadFull(r, encoded); err != nil {
		return err
	}
	// message is already the caller's pointer. Passing it directly (rather
	// than &message) makes a non-pointer argument an error instead of
	// silently decoding into a local copy.
	return json.Unmarshal(encoded, message)
}

// sendWebExtensionMessage frames message for the browser's native-messaging
// channel and writes it to w. It writes the payload length first, as a uint32
// in NativeEndian byte order, and then the JSON encoding of message.
// https://developer.mozilla.org/en-US/docs/Mozilla/Add-ons/WebExtensions/Native_messaging#App_side
//
// The length prefix and the payload are written with two separate calls, so
// after an error w may hold a partial frame.
func sendWebExtensionMessage(w io.Writer, message interface{}) error {
	payload, err := json.Marshal(message)
	if err != nil {
		return err
	}
	// The length prefix is only 32 bits wide.
	if size := len(payload); uint64(size) > math.MaxUint32 {
		return fmt.Errorf("WebExtension message is too long to represent: %d", size)
	}
	if err := binary.Write(w, NativeEndian, uint32(len(payload))); err != nil {
		return err
	}
	_, err = w.Write(payload)
	return err
}

// fetch forwards one client request to the browser and waits for the answer.
// It does the following:
//   - wraps request in a webExtensionRequest tagged with a fresh random ID;
//   - registers a result channel for that ID in resultMap;
//   - hands the envelope to outToBrowserLoop through outToBrowserChan;
//   - waits for inFromBrowserLoop to find the channel by ID and deliver the
//     browser's reply on it.
//
// If no reply arrives within waitForBrowserTimeout, fetch removes the
// registration and returns an error. On success the returned fetchResult is
// the one carried in the browser's reply; on error it is nil.
func fetch(request interface{}, outToBrowserChan chan<- *webExtensionRequest) (*fetchResult, error) {
	// Generate an ID that will allow us to match a response to this request.
	idRaw := make([]byte, 8)
	if _, err := rand.Read(idRaw); err != nil {
		return nil, err
	}
	id := hex.EncodeToString(idRaw)

	// inFromBrowserLoop sends at most one value on this channel, because it
	// removes the entry from resultMap before sending. The buffer of one
	// guarantees that send never blocks. Without the buffer there is a race:
	// if the timeout below fired just after inFromBrowserLoop removed the
	// entry but before it sent, inFromBrowserLoop (the only reader of stdin)
	// would block forever on a channel nobody receives from.
	resultChan := make(chan *fetchResult, 1)
	resultMapLock.Lock()
	resultMap[id] = resultChan
	resultMapLock.Unlock()

	// Send the tagged request to the browser.
	outToBrowserChan <- &webExtensionRequest{
		ID:      id,
		Request: request,
	}

	// Wait for inFromBrowserLoop to deliver the reply, but not forever, so
	// that memory can be reclaimed if the request is lost somewhere.
	timeout := time.NewTimer(waitForBrowserTimeout)
	defer timeout.Stop()
	select {
	case result := <-resultChan:
		return result, nil
	case <-timeout.C:
		resultMapLock.Lock()
		delete(resultMap, id)
		resultMapLock.Unlock()
		return nil, fmt.Errorf("timed out waiting for browser to reply")
	}
}

// errorObj describes an error produced by this program itself, as opposed to
// one relayed from the browser. Its single field serializes as "message", so
// the object has the same shape as the message property of a JavaScript Error.
// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Error/prototype#Standard_properties
type errorObj struct {
	Message string `json:"message"` // human-readable description
}

// handleConn serves one client connection, which carries exactly one
// request–response round trip through the browser. It proceeds as follows:
//   - reads the request specification from conn with recvRequest;
//   - forwards it with fetch;
//   - writes the result back to conn with sendResult.
//
// An error from fetch is not returned. Instead it is sent to the client as a
// fetchResult whose Error holds an errorObj. Errors from reading the request
// or writing the result are returned with context. If reading the request
// fails, nothing is written back. conn is always closed before handleConn
// returns.
func handleConn(conn net.Conn, outToBrowserChan chan<- *webExtensionRequest) error {
	defer conn.Close()

	request, err := recvRequest(conn)
	if err != nil {
		return fmt.Errorf("reading request: %w", err)
	}

	result, err := fetch(request, outToBrowserChan)
	if err != nil {
		// Report local failures (such as a timeout) in the same shape the
		// browser uses for a rejected fetch.
		result = &fetchResult{Error: &errorObj{Message: err.Error()}}
	}

	// Encode the result and send it back over the socket.
	if err := sendResult(conn, result); err != nil {
		return fmt.Errorf("sending result: %w", err)
	}
	return nil
}

// acceptLoop accepts connections on ln and serves each one in its own
// goroutine with handleConn. Any error handleConn returns is logged to stderr.
//
// Accept errors that report themselves as temporary (e.g. running out of file
// descriptors) are retried after a delay. The delay starts at 5ms, doubles on
// each consecutive failure, and is capped at one second, so a persistent
// condition does not become a busy loop. Any other accept error ends the loop
// and is returned.
func acceptLoop(ln net.Listener, outToBrowserChan chan<- *webExtensionRequest) error {
	var backoff time.Duration
	for {
		conn, err := ln.Accept()
		if err != nil {
			// Temporary is deprecated in general, but for Accept it is still
			// how net/http's server decides whether to retry.
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				if backoff == 0 {
					backoff = 5 * time.Millisecond
				} else {
					backoff *= 2
				}
				if backoff > time.Second {
					backoff = time.Second
				}
				fmt.Fprintf(os.Stderr, "accepting connection: %v; retrying in %v\n", err, backoff)
				time.Sleep(backoff)
				continue
			}
			return err
		}
		backoff = 0
		go func() {
			if err := handleConn(conn, outToBrowserChan); err != nil {
				fmt.Fprintln(os.Stderr, "handling socket request:", err)
			}
		}()
	}
}

// inFromBrowserLoop is the only code that reads stdin. It decodes each
// webExtensionResult the browser sends and routes the embedded fetchResult to
// the channel that fetch registered in resultMap under the same ID. A channel
// is removed from resultMap when it is routed to, and it is closed right after
// its single value is sent. The loop returns only when reading a message fails
// (including io.EOF when stdin ends), because the stream may then be out of
// sync.
func inFromBrowserLoop() error {
	for {
		// Declared inside the loop so each result handed off below has its
		// own storage.
		var msg webExtensionResult
		if err := recvWebExtensionMessage(os.Stdin, &msg); err != nil {
			return err
		}

		resultMapLock.Lock()
		dest, found := resultMap[msg.ID]
		delete(resultMap, msg.ID)
		resultMapLock.Unlock()

		// A missing entry means either the browser invented an ID we never
		// sent or, more likely, fetch already gave up waiting. Either way
		// the result is dropped.
		if !found {
			continue
		}
		dest <- &msg.fetchResult
		close(dest)
	}
}

// outToBrowserLoop is the only code that writes stdout. It forwards every
// request received on outToBrowserChan to the browser as a native-messaging
// frame. It returns nil once the channel is closed. It returns the first
// write error immediately, because after a failed write the stream cannot be
// trusted.
func outToBrowserLoop(outToBrowserChan <-chan *webExtensionRequest) error {
	for {
		request, open := <-outToBrowserChan
		if !open {
			return nil
		}
		if err := sendWebExtensionMessage(os.Stdout, request); err != nil {
			// The browser may or may not have received (part of) the
			// frame, so the stream may be desynchronized; give up.
			return err
		}
	}
}

// main does the following:
//   - listens on listenAddr for client connections;
//   - starts the goroutines that accept connections, write requests to the
//     browser on stdout, and read results from the browser on stdin;
//   - waits until it should exit.
//
// It exits on SIGTERM, when stdin reaches EOF, or when any of those
// goroutines returns. A non-nil error from a goroutine is printed to stderr.
// https://developer.mozilla.org/en-US/docs/Mozilla/Add-ons/WebExtensions/Native_messaging#Closing_the_native_app
func main() {
	ln, err := net.Listen("tcp", listenAddr)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	defer ln.Close()

	// signal.Notify never blocks when delivering a signal. On an unbuffered
	// channel, a signal that arrives while nobody is receiving is simply
	// dropped. Use a buffer of one and register before starting any work so
	// that a SIGTERM is never lost.
	signalChan := make(chan os.Signal, 1)
	signal.Notify(signalChan, syscall.SIGTERM)

	outToBrowserChan := make(chan *webExtensionRequest)
	errChan := make(chan error)

	// Goroutine that handles new socket connections.
	go func() {
		errChan <- acceptLoop(ln, outToBrowserChan)
	}()

	// Goroutine that writes WebExtension messages to stdout.
	go func() {
		errChan <- outToBrowserLoop(outToBrowserChan)
	}()

	// Goroutine that reads WebExtension messages from stdin.
	go func() {
		err := inFromBrowserLoop()
		if err == io.EOF {
			// EOF means the browser closed stdin; that is a normal
			// shutdown, not an error to display.
			err = nil
		}
		errChan <- err
	}()

	// Quit on SIGTERM, when stdin is closed, or when some unrecoverable
	// error happens.
	select {
	case <-signalChan:
	case err := <-errChan:
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
		}
	}
}
