#!/usr/bin/env python3

# Converts relative URL references (in, e.g., <img src="..."> and
# <link href="...">) to data: URLs.

import base64
import html.parser
import getopt
import pathlib
import sys
import urllib.parse

# Media type to put in a generated data: URL, keyed by the file suffix
# (including the leading dot) of the referenced file.
MEDIATYPES = {".png": "image/png"}

# For each tag name that gets rewritten, the set of attribute names whose
# values are URL references to convert.
TO_FILTER = {
    "img": {"src"},
    "link": {"href"},
}

def convert_to_data_uri(base_uri, uri):
    """Return a data: URL equivalent to the reference *uri*.

    *uri* is first resolved against *base_uri*. If the result is already a
    data: URL, *uri* is returned unchanged. If it is a file: URL without a
    network location, the file is read and returned as a base64-encoded
    data: URL whose media type is looked up in MEDIATYPES by file suffix.

    Raises:
        ValueError: if the resolved URL is neither a data: URL nor a local
            file: URL.
        KeyError: if the file's suffix has no entry in MEDIATYPES.
        OSError: if the referenced file cannot be opened or read.
    """
    u = urllib.parse.urlparse(urllib.parse.urljoin(base_uri, uri))
    if u.scheme == "data":
        return uri
    if u.scheme == "file" and not u.netloc:
        # The path component is percent-encoded; decode it before touching
        # the filesystem.
        path = pathlib.Path(urllib.parse.unquote(u.path))
        with open(path, "rb") as f:
            mediatype = MEDIATYPES[path.suffix]
            return f"data:{mediatype};base64," + base64.b64encode(f.read()).decode()
    # Report the reference exactly as the caller supplied it.
    raise ValueError(f"unhandled: {uri!r}")

class Parser(html.parser.HTMLParser):
    """Copy an HTML document to out_file, rewriting selected URL attributes.

    The parser does not re-serialize the document. It remembers a position
    (self.pos) up to which the input has been accounted for, and copies the
    original text between that position and the current parser position
    verbatim from in_file. Only start tags listed in TO_FILTER are written
    out anew, with the listed attributes converted by convert_to_data_uri.

    in_file must be a seekable text file; positions reported by getpos() are
    used as (line, column) indexes into it, so it is expected to contain the
    same text that is passed to feed().

    Call finish() after feeding all the data to flush the remaining output.
    """

    def __init__(self, in_file, out_file, base_uri):
        super().__init__()
        self.in_file = in_file
        self.out_file = out_file
        self.base_uri = base_uri
        # At most one callback to run at the start of the next parser event.
        self.anything_handler = None
        # Initialize self.pos when the first event arrives (or in finish()).
        self.register_on_anything_event_handler(Parser.set_pos)

    def output_between_positions(self, pos1, pos2):
        """Copy the text of in_file from pos1 up to (not including) pos2.

        Positions are (line, column) pairs as returned by getpos(): lines are
        numbered from 1 and columns from 0. in_file is re-read from the start
        on every call.
        """
        line1, col1 = pos1
        line2, col2 = pos2

        # Skip up to line1.
        self.in_file.seek(0)
        for _ in range(1, line1):
            self.in_file.readline()
        # Read line1 itself.
        line = self.in_file.readline()
        if line1 == line2:
            # If line1 and line2 are the same, output a slice of the line.
            self.out_file.write(line[col1:col2])
        else:
            # Output everything after col1 in line1.
            self.out_file.write(line[col1:])
            # Output complete lines up to the beginning of line2.
            for _ in range(line1 + 1, line2):
                self.out_file.write(self.in_file.readline())
            # Read line2 itself.
            line = self.in_file.readline()
            # Output everything up to col2 in line2.
            self.out_file.write(line[:col2])

    def set_pos(self):
        """Record the current parser position as the start of pending output."""
        self.pos = self.getpos()

    def filter_start_tag(self, tag, attrs, end):
        """Handle a start tag (end=False) or a self-closing tag (end=True).

        Flushes the original text that precedes the tag. If the tag is listed
        in TO_FILTER, a rewritten tag is emitted in place of the original
        text; otherwise the original tag text is left to be copied verbatim
        by a later flush.
        """
        self.output_between_positions(self.pos, self.getpos())
        self.pos = self.getpos()
        filter_attrs = TO_FILTER.get(tag)
        if filter_attrs is None:
            return
        self.out_file.write(f"<{html.escape(tag)}")
        for key, value in attrs:
            if value is None:
                # HTMLParser reports an attribute without a value (such as
                # <img ismap>) as (name, None); emit just the name, since
                # there is nothing to convert or escape.
                self.out_file.write(f" {html.escape(key)}")
                continue
            if key in filter_attrs:
                value = convert_to_data_uri(self.base_uri, value)
            self.out_file.write(f' {html.escape(key)}="{html.escape(value)}"')
        self.out_file.write(" />" if end else ">")
        # Update the pos for the next output_between_positions immediately
        # after this start or start+end tag, so that the original text of
        # the tag is skipped.
        self.register_on_anything_event_handler(Parser.set_pos)

    def register_on_anything_event_handler(self, fn):
        """Schedule fn(self) to run at the start of the next parser event.

        Raises:
            ValueError: if a handler is already pending.
        """
        if self.anything_handler is not None:
            raise ValueError(self.anything_handler)
        self.anything_handler = fn

    def process_events(self):
        """Run and clear the pending handler, if there is one."""
        if self.anything_handler is not None:
            fn = self.anything_handler
            self.anything_handler = None
            fn(self)

    def handle_starttag(self, tag, attrs):
        self.process_events()
        self.filter_start_tag(tag, attrs, False)

    def handle_startendtag(self, tag, attrs):
        self.process_events()
        self.filter_start_tag(tag, attrs, True)

    def handle_endtag(self, tag):
        self.process_events()

    def handle_data(self, data):
        self.process_events()

    def handle_entityref(self, name):
        self.process_events()

    def handle_charref(self, name):
        self.process_events()

    def handle_comment(self, data):
        self.process_events()

    def handle_decl(self, decl):
        self.process_events()

    def handle_pi(self, data):
        self.process_events()

    def unknown_decl(self, data):
        # Unrecognized declarations are parser events too; without this hook
        # a pending position update would be deferred past them.
        self.process_events()

    def finish(self):
        """Flush all remaining output; call once after the last feed()."""
        # Make the parser process any text it is still buffering (for
        # example trailing text ending in an unterminated "&..."), so that
        # getpos() reaches the end of the fed data.
        self.close()
        # If no event followed the last rewritten tag (or no event occurred
        # at all), the pending handler has not run yet. Run it now so that
        # self.pos is defined and the original text of an already rewritten
        # tag is not output a second time.
        self.process_events()
        self.output_between_positions(self.pos, self.getpos())

def process(in_file, out_file, base_uri):
    """Run the whole of in_file through a Parser that writes to out_file.

    base_uri is handed to the Parser, which uses it to resolve the URL
    references it rewrites. in_file is read in full and fed to the parser
    in a single call.
    """
    parser = Parser(in_file, out_file, base_uri)
    parser.feed(in_file.read())
    parser.finish()

def main(argv=None):
    """Command-line entry point: convert one HTML file, writing to stdout.

    argv is the list of command-line arguments without the program name;
    it defaults to sys.argv[1:]. Exactly one positional argument, the path
    of the input HTML file, is expected. The file's own location is used as
    the base URI for resolving relative references.

    Raises:
        SystemExit: with a usage message if the number of positional
            arguments is not exactly one.
        getopt.GetoptError: if an unrecognized option is given.
    """
    if argv is None:
        argv = sys.argv[1:]
    _, args = getopt.gnu_getopt(argv, "")
    if len(args) != 1:
        sys.exit(f"usage: {sys.argv[0]} INPUT.html")
    (input_path,) = args
    base_uri = pathlib.Path(input_path).absolute().as_uri()
    with open(input_path) as in_file:
        process(in_file, sys.stdout, base_uri)


if __name__ == "__main__":
    main()
