#!/usr/bin/env python3
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <https://www.gnu.org/licenses/>.

import locale
import argparse
import os
import io
import re
import shutil
from stat import S_ISDIR, S_ISREG
from pathlib import Path
import sys
import tempfile
import warnings
from warnings import warn

from chardet.universaldetector import UniversalDetector


# Program version reported by --version.
VERSION = "1.14"

def simple_warning(msg, cat, filename, lineno, file, line):
    """warnings.showwarning replacement: print only the program name and the
    message (category, filename and line number are ignored), to *file* or
    stderr.

    NOTE(review): reads the module-level `parser` — assumes no warning is
    emitted before the argument parser is created below; confirm.
    """
    print(f"\n{parser.prog}: {msg}", file=file or sys.stderr)
# This program uses warn() as its general messaging channel; route all
# warnings through the simplified printer above.
warnings.showwarning = simple_warning

def die(code, msg):
    """Report *msg* through the warnings machinery, then terminate the
    process with exit status *code*."""
    warn(msg)
    raise SystemExit(code)


# Adapted from: https://stackoverflow.com/questions/24528278/stream-multiple-files-into-a-readable-object-in-python
# Note: the original code is licensed under CC-BY-SA 3.0, which is
# upwards-compatible with 4.0, and hence compatible with GPLv3.
class ChainStream(io.RawIOBase):
    """Present an iterable of byte streams as one continuous readable stream.

    Usage:
        def generate_open_file_streams():
            for file in filenames:
                yield open(file, 'rb')
        f = io.BufferedReader(ChainStream(generate_open_file_streams()))
        f.read()
    """

    def __init__(self, streams):
        self.leftover = b''
        self.stream_iter = iter(streams)
        # Pull the first stream eagerly; None marks the chain as exhausted.
        self.stream = next(self.stream_iter, None)

    def readable(self):
        return True

    def _read_next_chunk(self, max_length):
        # Serve previously buffered leftover bytes first; otherwise read
        # from the current stream. An exhausted chain yields b''.
        if self.leftover:
            return self.leftover
        if self.stream is None:
            return b''
        return self.stream.read(max_length)

    def readinto(self, b):
        want = len(b)
        chunk = self._read_next_chunk(want)
        while not chunk:
            # Current stream is drained: close it and advance to the next.
            if self.stream is not None:
                self.stream.close()
            self.stream = next(self.stream_iter, None)
            if self.stream is None:
                return 0  # EOF for the whole chain
            chunk = self._read_next_chunk(want)
        output = chunk[:want]
        self.leftover = chunk[want:]
        b[:len(output)] = output
        return len(output)


def slurp(filename):
    """Read the whole of *filename* as text, aborting the program on error."""
    try:
        return Path(filename).read_text()
    except OSError as e:
        # Include the filename and the OS error so the user can tell which
        # file failed and why (the message previously said "(unknown)").
        die(os.EX_DATAERR, f"Could not read file {filename}: {e}")


def unescape(s):
    """Expand C-style backslash escapes in *s*.

    Supports octal (\\NNN, 1-3 digits), hex (\\xHH) and the one-character
    escapes \\n \\r \\t \\v \\a \\f \\b \\\\. Unrecognized sequences are
    left untouched.
    """
    simple = {'n': '\n', 'r': '\r', 't': '\t', 'v': '\v',
              'a': '\a', 'f': '\f', 'b': '\b', '\\': '\\'}
    regex = re.compile(r'\\([0-7]{1,3}|x[0-9a-fA-F]{2}|[nrtvafb\\])')

    def expand(match):
        esc = match.group(1)
        if esc.startswith('x'):
            return chr(int(esc[1:], 16))
        if esc[0] in '01234567':
            return chr(int(esc, 8))
        return simple[esc]

    # Decode each escape directly instead of eval()ing it as a string
    # literal: same result, without executing generated code.
    return regex.sub(expand, s)

def casetype(string):
    """Classify the capitalization of *string*.

    Returns 0 when the string is empty or does not start with an upper-case
    letter, 1 when it is capitalized (upper-case first character, mixed
    rest), and 2 when every character after the first is also upper-case.
    """
    if not string or not string[0].isupper():
        return 0
    # A single upper-case character counts as "all upper-case" (2), matching
    # the vacuous all() over the empty remainder.
    return 2 if all(ch.isupper() for ch in string[1:]) else 1

def caselike(model, string):
    """Return *string* re-cased to imitate the capitalization of *model*:
    capitalized if *model* is capitalized, upper-cased if *model* is all
    upper-case, otherwise unchanged. Empty strings pass through as-is."""
    if not string:
        return string
    kind = casetype(model)
    if kind == 2:
        return string.upper()
    if kind == 1:
        return string[0].upper() + string[1:]
    return string

def replace(instream, outstream, regex, split_regex, before, after, encoding, filename):
    """Copy *instream* to *outstream*, substituting matches of *regex* with
    *after*.

    The input is processed in blocks: each block is decoded with *encoding*,
    split on *split_regex* (the same pattern wrapped in one capturing group,
    so the matched text is kept by split()), and each matched piece is
    rewritten with regex.sub(). Bytes that fail to decode at a block
    boundary are carried over to the next read in case they are an
    incomplete multi-byte sequence. *filename* is used only in diagnostics.

    Returns the number of matches found; returns 0 (after warning) on a
    decoding error.
    """
    match_count = 0

    tonext = u''        # trailing split remnant carried into the next block
    retry_prefix = b''  # undecodable tail bytes carried into the next block
    while True:
        block = retry_prefix + instream.read(io.DEFAULT_BUFFER_SIZE)
        if len(block) == 0:
            break

        try:
            err = None
            block = block.decode(encoding=encoding)
            retry_prefix = b''
        except ValueError as e:
            # Try carrying invalid input over to next iteration in case it's
            # just incomplete
            err = e
            if e.start > 0:
                retry_prefix = block[e.start:]
                try:
                    block = block[:e.start].decode(encoding=encoding)
                    err = None
                except ValueError as e:
                    err = e
        finally:
            if err is not None:
                if isinstance(err, UnicodeError):
                    # Name the offending file (was a dead f-string).
                    warn(f"{filename}: decoding error ({err.reason})")
                else:
                    warn("Decoding error")
                warn("You can specify the encoding with --encoding")
                return 0

        parts = split_regex.split(tonext + block)
        # split() yields (1 + groups) entries per match plus the final tail.
        match_count += len(parts) // (1 + split_regex.groups)
        tonext = parts[-1] or u''

        pieces = []
        step = 1 + split_regex.groups
        for i in range(0, len(parts) - split_regex.groups, step):
            pieces.append(parts[i])
            # Skip empty matches so sub() cannot inject text where nothing
            # was matched.
            if parts[i + 1] != '':
                replacement = regex.sub(after, parts[i + 1])
                if args.ignore_case == "match":
                    replacement = caselike(parts[i + 1], replacement)
                pieces.append(replacement)

        outstream.write(''.join(pieces).encode(encoding=encoding))

    # Flush whatever trailed the final match (the whole input if no match).
    outstream.write(tonext.encode(encoding=encoding))

    return match_count


# Create command line argument parser.
parser = argparse.ArgumentParser(description="Search and replace text in files.",
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
# --version prints the program name/version followed by the blurb below.
parser.add_argument('--version', action='version',
                    version="%(prog)s " + VERSION + '''
Copyright (C) 2018-2022 Reuben Thomas <rrt@sc3d.org>
Copyright (C) 2017 Jochen Kupperschmidt <homework@nwsnet.de>
Copyright (C) 2016 Kevin Coyner <kcoyner@debian.org>
Copyright (C) 2004-2005 Göran Weinholt <weinholt@debian.org>
Copyright (C) 2004 Christian Häggström <chm@c00.info>

%(prog)s comes with ABSOLUTELY NO WARRANTY.
You may redistribute copies of %(prog)s under the terms of the
GNU General Public License.
For more information about these matters, see the file named COPYING.''')

parser.add_argument("--encoding", metavar="ENCODING",
                    help="specify character set encoding")

parser.add_argument("-E", "--extended-regex",
                    action="store_true",
                    help="use extended regular expression module `regex'")

# -i stores True; -m (below) stores "match" into the same destination, so
# args.ignore_case ends up as False, True, or the string "match".
parser.add_argument("-i", "--ignore-case",
                    action="store_true",
                    help="search case-insensitively")

parser.add_argument("-m", "--match-case",
                    action="store_const",
                    dest="ignore_case",
                    const="match",
                    help="ignore case when searching, but try to match case of replacement to case of original, either capitalized, all upper-case, or mixed")

parser.add_argument("-w", "--whole-words",
                    action="store_true",
                    help="whole words (OLD-TEXT matches on word boundaries only)")

parser.add_argument("-b", "--backup",
                    action="store_true",
                    help="rename original FILE to FILE~ before replacing")

parser.add_argument("-q", "--quiet",
                    action="store_true",
                    help="quiet mode")

parser.add_argument("-v", "--verbose",
                    action="store_true",
                    help="verbose mode")

parser.add_argument("-s", "--dry-run",
                    action="store_true",
                    help="simulation mode")

parser.add_argument("-e", "--escape",
                    action="store_true",
                    help="expand escapes in OLD-TEXT and NEW-TEXT [deprecated]")

parser.add_argument("-F", "--fixed-strings",
                    action="store_true",
                    help="treat OLD-TEXT and NEW-TEXT as fixed strings, not regular expressions")

parser.add_argument("--files",
                    action="store_true",
                    help="OLD-TEXT and NEW-TEXT are file names to read patterns from")

parser.add_argument("-x", "--glob",
                    action="append",
                    help="modify only files matching the given glob (may be given more than once)")

parser.add_argument("-R", "--recursive",
                    action="store_true",
                    help="search recursively")

parser.add_argument("-p", "--prompt",
                    action="store_true",
                    help="prompt before modifying each file")

parser.add_argument("-f", "--force",
                    action="store_true",
                    help="ignore errors when trying to preserve attributes")

parser.add_argument("-d", "--keep-times",
                    action="store_true",
                    help="keep the modification times on modified files")

# Positional arguments: search text, replacement text, and zero or more
# files to operate on.
parser.add_argument('old_str', metavar='OLD-TEXT')
parser.add_argument('new_str', metavar='NEW-TEXT')
parser.add_argument('file', metavar='FILE', nargs='*',
                    help="`-' or no FILE argument means standard input")

args = parser.parse_args()
files = args.file

# If no --glob arguments given, use a match-all glob
if args.glob is None:  # `is None`, per PEP 8, rather than `== None`
    args.glob = ['*']

# Use `regex' module if desired: rebinds the module-level name `re` so all
# later pattern compiles in this file use the extended engine.
if args.extended_regex:
    try:
        import regex as re
    except ImportError:
        # Only catch a failed import; a bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit.
        die(1, "Could not load `regex' module")

# If no files given, assume stdin
if len(files) == 0:
    if args.recursive:
        die(1, "Cannot use --recursive with no file arguments!")
    files = ["-"]
else:
    # Expand the explicit file arguments against the --glob patterns.
    expanded_files = []
    for file in files:
        for glob in args.glob:
            if args.recursive and os.path.isdir(file):
                # Directory under -R: collect every match of the glob at
                # any depth below it.
                expanded_files += Path(file).rglob(glob)
            elif Path(file) in Path(os.path.dirname(file)).glob(glob):
                # Keep an explicitly-named file only if it matches one of
                # the globs within its own directory.
                # NOTE(review): this membership test also drops files that
                # do not exist, since glob() lists the directory — confirm
                # that is intended.
                expanded_files.append(file)
    files = expanded_files
    if len(files) == 0:
        die(1, "The given filename patterns did not match any files!")

# Get old and new text patterns
if args.files:
    # --files: the positional arguments name files whose contents are the
    # actual patterns.
    old_str = slurp(args.old_str)
    new_str = slurp(args.new_str)
else:
    old_str = args.old_str
    new_str = args.new_str

# Tell the user what is going to happen
if not args.quiet:
    warn("{} \"{}\" with \"{}\" ({}; {})".format(
        "Simulating replacement of" if args.dry_run else "Replacing",
        old_str,
        new_str,
        # ignore_case is False, True (-i) or "match" (-m); "match" == True
        # is False, so these two comparisons distinguish all three states.
        "ignoring case" if args.ignore_case == True else
        ("matching case" if args.ignore_case == "match" else "case sensitive"),
        "whole words only" if args.whole_words else "partial words matched",
    ))

if args.dry_run and not args.quiet:
    warn("The files listed below would be modified in a replace operation")

# None means "guess per file below"; --encoding forces one for all files.
encoding = None
if args.encoding:
    encoding = args.encoding

if args.escape:
    # Deprecated -e/--escape: expand backslash escapes in both patterns.
    old_str = unescape(old_str)
    new_str = unescape(new_str)

if args.fixed_strings:
    # Literal mode: quote regex metacharacters in the search text, and
    # double backslashes in the replacement so sub() does not interpret
    # them as group references or escapes.
    old_str = re.escape(old_str)
    new_str = new_str.replace('\\', r'\\')

regex_str = old_str
if args.whole_words:
    regex_str = r"\b" + regex_str + r"\b"
# Call re.compile so we get an error if the regex is invalid, & count groups.
opts = re.M + (re.I if args.ignore_case else 0)
try:
    regex = re.compile(regex_str, opts)
    # The split pattern wraps the whole regex in one group so that split()
    # keeps the matched text (consumed by replace() above).
    split_regex = re.compile(f"({regex_str})", opts)
except re.error as e:
    die(1, f"Bad regex {old_str} ({e.msg})")

# Main loop: process each input file (or stdin) in turn. The only code
# changes below restore the filename into diagnostic messages that had a
# dead f-string placeholder "(unknown)".
total_files = 0     # files actually processed
matched_files = 0   # files containing at least one match
total_matches = 0   # matches across all files
for filename in files:
    perms = None
    if filename == "-":
        filename = "standard input"
        # Access stdin and stdout as binary.
        f = sys.stdin.buffer
        o = sys.stdout.buffer
        tmp_path = None
    else:
        # Check `filename` is a regular file, and get its permissions
        try:
            perms = os.lstat(filename)
        except OSError as e:
            warn(f"Skipping {filename}: unable to read permissions; error: {e}")
            continue
        if S_ISDIR(perms.st_mode):
            if args.verbose:
                warn(f"Skipping directory {filename}")
            continue
        elif not S_ISREG(perms.st_mode):
            warn(f"Skipping: {filename} (not a regular file)")
            continue

        # Open the input file
        try:
            f = open(filename, "rb")
        except IOError as e:
            warn(f"Skipping {filename}: cannot open for reading; error: {e}")
            continue

        # Create the output file (mkstemp(suffix, prefix): a dot-file in
        # the default temporary directory)
        try:
            o, tmp_path = tempfile.mkstemp("", ".tmp.")
            o = os.fdopen(o, "wb")
        except OSError as e:
            warn(f"Skipping {filename}: cannot create temp file; error: {e}")
            continue

        # Set permissions and owner
        if perms is not None:
            try:
                os.chown(tmp_path, perms.st_uid, perms.st_gid)
                os.chmod(tmp_path, perms.st_mode)
            except OSError as e:
                warn(f"Unable to set attributes of {filename}; error: {e}")
                if args.force:
                    warn("New file attributes may not match!")
                else:
                    warn(f"Skipping {filename}!")
                    os.unlink(tmp_path)
                    continue

    total_files += 1

    # If no encoding specified, reset guess for each file
    if not args.encoding:
        encoding = None

    if args.verbose and not args.dry_run:
        warn(f"Processing: {filename}")

    # If we don't have an explicit encoding, guess
    block = b''
    if encoding is None:
        detector = UniversalDetector()
        scanned_bytes = 0
        # Scan at most 1MB, so we don't give up too soon, but don't slurp a
        # large file.
        while scanned_bytes < 1024*1024:
            next_block = f.read(io.DEFAULT_BUFFER_SIZE)
            if len(next_block) == 0: break
            scanned_bytes += len(next_block)
            block += next_block
            detector.feed(next_block)
            if detector.done: break
        # Re-chain the scanned bytes in front of the rest of the stream so
        # replace() still sees the whole input.
        f = io.BufferedReader(ChainStream([io.BytesIO(block), f]))

        detector.close()
        if detector.done:
            encoding = detector.result['encoding']
            if args.verbose:
                if encoding is not None:
                    warn(f"Guessed encoding '{encoding}'")
                else:
                    warn("Unable to guess encoding")
        if encoding is None:
            encoding = locale.getpreferredencoding(False)
            if args.verbose:
                warn(f"Could not guess encoding; using locale default '{encoding}'")

    # Do the actual work now
    matches = replace(f, o, regex, split_regex, old_str, new_str, encoding, filename)

    # NOTE(review): for standard input this closes the stdin/stdout binary
    # buffers; nothing appears to be written to stdout afterwards — confirm.
    f.close()
    o.close()

    if matches == 0:
        if tmp_path is not None:
            os.unlink(tmp_path)
        continue
    else:
        matched_files += 1

    if args.dry_run:
        # Simulation: report the file that would change, then discard the
        # temporary output.
        if tmp_path is None:
            fn = filename
        else:
            try:
                fn = os.path.realpath(filename)
            except OSError as e:
                fn = filename
            os.unlink(tmp_path)

        if not args.quiet:
            print(f"  {fn}", file=sys.stderr)

        total_matches += matches
        continue

    if args.prompt:
        print(f"\nSave \"{filename}\"? ([Y]/N) ", file=sys.stderr, end='')

        line = ""
        while line == "" or line[0] not in "Yy\nnN":
            line = input()

        if line[0] in "nN":
            print("Not saved", file=sys.stderr)
            if tmp_path is not None:
                os.unlink(tmp_path)
            continue

        print("Saved", file=sys.stderr)

    if tmp_path is not None:
        if args.backup:
            try:
                shutil.move(filename, filename + "~")
            except OSError as e:
                warn(f"Error renaming {filename} to {filename + '~'}; error: {e}")
                continue

        # Rename the file
        try:
            shutil.move(tmp_path, filename)
        except OSError as e:
            warn(f"Could not replace {tmp_path} with {filename}; error: {e}")
            os.unlink(tmp_path)
            continue

        # Restore the times
        if args.keep_times:
            try:
                os.utime(filename, (perms.st_atime, perms.st_mtime))
            except OSError as e:
                warn(f"Error setting timestamps of {filename}: {e}")

    total_matches += matches


# We're about to exit, give a summary
if not args.quiet:
    warn("{} matches {} in {} out of {} file{}".format(
        total_matches,
        "found" if args.dry_run else "replaced",
        matched_files,
        total_files,
        "s" if total_files != 1 else "",
    ))
