import datetime
import os
import re
import gzip
import warcio
import argparse
import time
import zstandard as zstd
from warcio.archiveiterator import ArchiveIterator
import io
from concurrent.futures import ProcessPoolExecutor

def read_skippable_frame(file):
    """Read the payload of a Zstandard skippable frame at the current position.

    A skippable frame (RFC 8878, section 3.1.2) consists of a 4-byte
    little-endian magic number in the range 0x184D2A50-0x184D2A5F, a 4-byte
    little-endian payload size, and that many bytes of user data.

    Args:
        file: A seekable binary file object positioned at the start of a frame.

    Returns:
        The frame payload as ``bytes``, or ``None`` if the data at the current
        position is not a skippable frame. In the ``None`` case the file
        position is restored, so the caller can go on to decode the same bytes
        as ordinary zstd frames.

    Raises:
        ValueError: If a skippable frame header is present but the frame is
            truncated.
    """
    start = file.tell()
    magic_number = file.read(4)
    # Skippable magics are 0x184D2A5? stored little-endian: the low nibble of
    # the first byte is free, the remaining three bytes are fixed.
    is_skippable = (
        len(magic_number) == 4
        and magic_number[1:] == b'\x2A\x4D\x18'
        and (magic_number[0] & 0xF0) == 0x50
    )
    if not is_skippable:
        # Rewind so the bytes we peeked at are not lost to the caller.
        file.seek(start)
        return None

    size_field = file.read(4)
    if len(size_field) != 4:
        raise ValueError('Truncated zstd skippable frame header')
    frame_size = int.from_bytes(size_field, 'little')
    payload = file.read(frame_size)
    if len(payload) != frame_size:
        raise ValueError('Truncated zstd skippable frame payload')
    return payload

# Little-endian magic numbers used by the .warc.zst container.
_ZSTD_FRAME_MAGIC = b'\x28\xB5\x2F\xFD'  # ordinary zstd frame (0xFD2FB528)
_WARC_ZSTD_DICT_MAGIC = b'\x5D\x2A\x4D\x18'  # dictionary skippable frame (0x184D2A5D)

# Matches are buffered in memory and appended to the output file in batches.
_FLUSH_EVERY = 100


def _read_warc_zstd_dictionary(raw_file):
    """Return the dictionary stored at the start of a .warc.zst file, or None.

    A skippable frame with magic 0x184D2A5D may precede the data frames; its
    payload is the compression dictionary, possibly itself zstd-compressed.
    When no such frame is present the file position is restored.
    """
    start = raw_file.tell()
    header = raw_file.read(8)
    if len(header) < 8 or header[:4] != _WARC_ZSTD_DICT_MAGIC:
        raw_file.seek(start)
        return None
    frame_size = int.from_bytes(header[4:], 'little')
    dict_bytes = raw_file.read(frame_size)
    if len(dict_bytes) != frame_size:
        raise ValueError('Truncated zstd dictionary frame')
    if dict_bytes[:4] == _ZSTD_FRAME_MAGIC:
        # Compressed dictionary: decompressobj() does not require the frame
        # header to carry the content size, unlike ZstdDecompressor.decompress().
        dict_bytes = zstd.ZstdDecompressor().decompressobj().decompress(dict_bytes)
    # python-zstandard needs a ZstdCompressionDict object, not raw bytes.
    return zstd.ZstdCompressionDict(dict_bytes)


def _open_warc_stream(file_path, raw_file):
    """Wrap *raw_file* in a reader that yields uncompressed WARC bytes.

    The decompressor is chosen from the file extension. For uncompressed input
    *raw_file* itself is returned.
    """
    if file_path.endswith('.warc.gz'):
        return gzip.GzipFile(fileobj=raw_file, mode='rb')
    if file_path.endswith('.warc.zst'):
        dctx = zstd.ZstdDecompressor(dict_data=_read_warc_zstd_dictionary(raw_file))
        # Stream instead of decompressing the whole file into memory, and keep
        # reading past frame boundaries because the file can hold many frames.
        return dctx.stream_reader(raw_file, read_across_frames=True)
    return raw_file


def _append_matches(output_file_path, matches):
    """Append *matches* to *output_file_path*, one per line, as UTF-8."""
    # Explicit encoding: decode(errors='replace') can produce U+FFFD, which a
    # non-UTF-8 locale encoding would fail to write.
    with open(output_file_path, 'a', encoding='utf-8') as output_file:
        output_file.write('\n'.join(matches) + '\n')


def process_warc(args):
    """Scan one WARC file for regex matches, append them to a file, then delete it.

    Takes a single tuple so it can be handed directly to
    ``ProcessPoolExecutor.submit``.

    Args:
        args: ``(file_path, output_file_path, pattern)``.
            *file_path* is a ``.warc.gz``, ``.warc.zst`` or uncompressed WARC
            file. Matches found in every ``response`` record are appended to
            *output_file_path*. *pattern* is a compiled regex. Its ``findall``
            results are joined as strings, so the pattern presumably must not
            contain capture groups — confirm before changing it.

    Side effects:
        Prints progress roughly every 2% of the input file. Appends matches to
        the output file in batches of ``_FLUSH_EVERY``. Removes *file_path*
        after the whole file has been processed; if an exception propagates,
        the file is left in place.
    """
    file_path, output_file_path, pattern = args
    print(f"Processing file: {file_path}")
    matches_buffer = []
    file_size = os.path.getsize(file_path)
    # Start at -2 so that the first record (progress >= 0%) is always reported.
    last_printed_progress = -2

    with open(file_path, 'rb') as raw_file:
        with _open_warc_stream(file_path, raw_file) as warc_stream:
            for record in ArchiveIterator(warc_stream):
                if record.rec_type == 'response':
                    content = record.content_stream().read().decode(errors='replace')
                    for match in pattern.findall(content):
                        matches_buffer.append(match)
                        if len(matches_buffer) >= _FLUSH_EVERY:
                            _append_matches(output_file_path, matches_buffer)
                            matches_buffer = []

                # Measure progress on the on-disk (compressed) file. The
                # decompressed stream's offset is not comparable with file_size
                # and would overshoot 100%.
                progress = (raw_file.tell() / file_size) * 100
                if progress - last_printed_progress >= 2:
                    print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}:Progress on {file_path}: {progress:.2f}%")
                    last_printed_progress = progress

    # Write the remaining matches
    if matches_buffer:
        _append_matches(output_file_path, matches_buffer)

    print(f"Removing file: {file_path}")
    os.remove(file_path)  # Remove the WARC file after processing

def process_warcs_in_directory(warc_directory, output_file_path, pattern,
                               max_workers=14, poll_interval=60):
    """Poll *warc_directory* forever, processing every compressed WARC file found.

    Each cycle does the following, and the function never returns:

    1. List the directory.
    2. Submit each ``.warc.gz`` / ``.warc.zst`` file to a process pool running
       ``process_warc``.
    3. Wait for all of them and report any failures.
    4. Sleep for *poll_interval* seconds.

    ``process_warc`` deletes a file only after it has been fully processed, so
    a file whose processing failed is picked up again on the next cycle.

    Args:
        warc_directory: Directory to scan.
        output_file_path: File that matches are appended to.
        pattern: Compiled regex passed through to ``process_warc``.
        max_workers: Number of worker processes per cycle (default 14).
        poll_interval: Seconds to sleep between cycles (default 60).
    """
    while True:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            submitted = []
            for file_name in os.listdir(warc_directory):
                if file_name.endswith(('.warc.gz', '.warc.zst')):
                    file_path = os.path.join(warc_directory, file_name)
                    future = executor.submit(process_warc, (file_path, output_file_path, pattern))
                    submitted.append((file_path, future))

            # A worker's exception is stored on its future; unless we look at
            # it, it is silently discarded.
            for file_path, future in submitted:
                error = future.exception()
                if error is not None:
                    print(f"Error processing {file_path}: {error!r}")

        print(f"Sleeping {poll_interval}s")
        time.sleep(poll_interval)  # Wait before checking the directory again

def main(args):
    """Compile the imgur-matching pattern and scan ``args.warc_directory`` forever.

    Args:
        args: Parsed command-line namespace providing ``warc_directory`` and
            ``output_file_path``.
    """
    imgur_token = re.compile(r'\S*imgur\S*')
    directory, output_path = args.warc_directory, args.output_file_path

    while True:
        process_warcs_in_directory(directory, output_path, imgur_token)

if __name__ == '__main__':
    # Command-line entry point: both arguments are positional.
    cli = argparse.ArgumentParser(
        description='Process WARC files and find matches for a regex pattern.'
    )
    cli.add_argument('warc_directory', help='Path to the directory containing WARC files.')
    cli.add_argument('output_file_path', help='Path to the output file.')
    main(cli.parse_args())