import argparse
import io
import os
import re
import struct
from concurrent.futures import ThreadPoolExecutor, as_completed

import internetarchive as ia
import zstandard as zstd
from warcio.archiveiterator import ArchiveIterator


def list_warcs(collection_id):
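    """Search an Internet Archive collection and return (identifier, file_name)
    pairs for every WARC file (.warc, .warc.gz or .warc.zst) it contains."""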
    warcs = []

    for result in ia.search_items(f"collection:{collection_id}"):
        item = ia.get_item(result['identifier'])
        for file in item.files:
            file_name = file['name']
            if file_name.endswith(('.warc', '.warc.gz', '.warc.zst')):
                warcs.append((result['identifier'], file_name))

    return warcs


def download_warc(identifier, file_name, output_dir):
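    """Download a single file from an Internet Archive item into output_dir and
    return the local path it was saved to."""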
    output_path = os.path.join(output_dir, os.path.basename(file_name))
    ia.download(identifier, files=[file_name], destdir=output_dir, no_directory=True, retries=3)
    return output_path


def read_skippable_frame(stream):
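    """Return the payload of a leading Zstandard skippable frame (used by
    .warc.zst files to carry a decompression dictionary), or None if the
    stream does not start with one. The stream is left positioned just past
    the frame, or rewound to where it started if no frame is found."""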
    magic_bytes = stream.read(4)
    if len(magic_bytes) != 4:
        stream.seek(-len(magic_bytes), io.SEEK_CUR)
        return None
    magic = struct.unpack('<I', magic_bytes)[0]

    # Skippable frames use magic numbers 0x184D2A50 through 0x184D2A5F, followed
    # by a 4-byte little-endian frame size and that many bytes of payload.
    if 0x184D2A50 <= magic <= 0x184D2A5F:
        frame_size = struct.unpack('<I', stream.read(4))[0]
        return stream.read(frame_size)

    # Not a skippable frame: rewind so the decompressor sees the whole stream.
    stream.seek(-4, io.SEEK_CUR)
    return None


def extract_urls(file_path):
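    """Read a WARC file (.warc, .warc.gz or .warc.zst) and return all
    zippyshare.com URLs found in the bodies of text/html response records."""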
    urls = []

    # Match links to any zippyshare.com host, with an optional port and a path
    url_pattern = re.compile(r'https?://[^/\s]*zippyshare\.com(?::[0-9]+)?/\S*')

    with open(file_path, 'rb') as stream:
        if file_path.endswith('.warc.zst'):
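            # ArchiveTeam-style .warc.zst files typically begin with a skippable
            # frame carrying a custom zstd dictionary, which may itself be
            # zstd-compressed.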
            dictionary_data = read_skippable_frame(stream)
            if dictionary_data and dictionary_data[:4] == b'\x28\xb5\x2f\xfd':
                # The dictionary is itself a zstd frame; unpack it first.
                dictionary_data = zstd.ZstdDecompressor().decompressobj().decompress(dictionary_data)
            if dictionary_data:
                dctx = zstd.ZstdDecompressor(dict_data=zstd.ZstdDecompressionDict(dictionary_data))
            else:
                dctx = zstd.ZstdDecompressor()
            # Each record sits in its own zstd frame, so read across frame boundaries.
            stream = dctx.stream_reader(stream, read_across_frames=True)

        for record in ArchiveIterator(stream):
            # Some captures (e.g. non-HTTP records) carry no HTTP headers at all.
            if record.rec_type == 'response' and record.http_headers is not None:
                content_type = record.http_headers.get_header('Content-Type', '')
                if content_type.startswith('text/html'):
                    content = record.content_stream().read().decode('utf-8', errors='ignore')

                    # Find all URLs in the HTML content
                    found_urls = url_pattern.findall(content)

                    urls.extend(found_urls)

    return urls


def process_warcs(collection_id, output_dir, max_workers=4):
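    """Download every WARC in the collection, extract zippyshare.com URLs from
    each, and stream the results to a file named 'output' in the current
    directory. WARCs are processed concurrently and deleted after use."""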
    warcs = list_warcs(collection_id)

    def download_and_process(identifier, file_name, output_dir):
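        """Fetch one WARC, harvest its URLs, then delete the local copy."""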
        warc_file = download_warc(identifier, file_name, output_dir)
        urls = extract_urls(warc_file)
        os.remove(warc_file)
        return urls

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        tasks = {executor.submit(download_and_process, identifier, file_name, output_dir): (identifier, file_name) for
                 identifier, file_name in warcs}

        with open('output', 'w') as output_file:
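            # Write URLs as soon as each WARC finishes rather than waiting for all of them.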
            for task in as_completed(tasks):
                identifier, file_name = tasks[task]
                try:
                    urls = task.result()
                except Exception as exc:
                    print(f"Failed to process {identifier}/{file_name}: {exc}")
                    continue
                for url in urls:
                    output_file.write(f"{url}\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Download and process WARC files from the Internet Archive')
    parser.add_argument('collection_id', help='The ID of the collection to process')
    parser.add_argument('--output_dir', default='warcs', help='The directory to store the downloaded WARC files')

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    process_warcs(args.collection_id, args.output_dir)