#!/usr/bin/python3.3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011-2013 Julien Muchembled <jm@jmuchemb.eu>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse, errno, grp, logging.handlers, os
import pwd, re, signal, sqlite3, stat, struct, subprocess, sys, syslog
from contextlib import contextmanager
from ctypes import CDLL, util as ctypes_util, get_errno, c_long
from hashlib import md5
from pickle import dumps, loads
from urllib.parse import splitport
from posix1e import ACL, ACL_USER, ACL_GROUP, delete_default, Entry, Permset

# Granularity of data comparison/transfer (matches a common fs block size).
BLOCK_SIZE = 4096
# Amount of file data checked per RPC round-trip (1 MiB).
PART_SIZE = BLOCK_SIZE * 256

logger = logging.getLogger(__name__) # module-level logger

def read_rpc(stdin):
  """Read one length-prefixed RPC message from *stdin*.

  Returns the message payload, or an empty bytes object at EOF.
  """
  header = stdin.read(4)
  if not header:
    return header
  (length,) = struct.unpack('!I', header)
  return stdin.read(length)

def write_rpc(stdout, rpc):
  """Write *rpc* to *stdout*, prefixed with its length (big-endian u32)."""
  header = struct.pack('!I', len(rpc))
  stdout.write(header + rpc)

# NOTE: All paths are bytes instead of unicode.
#       This is required in order to handle paths with invalid characters.
#       For example, b'\x89' on a utf-8 filesystem:
#         sqlite3.connect(':memory:').execute('select ?', ('\udc89',))
#       raises UnicodeEncodeError
def _make_codecs(encoding):
  """Build the (decode, encode) pair bound to the filesystem encoding."""
  def decode(path):
    # surrogateescape keeps undecodable bytes round-trippable.
    return path.decode(encoding, 'surrogateescape')
  def encode(path):
    return path.encode(encoding)
  return decode, encode
decode, encode = _make_codecs(sys.getfilesystemencoding())

class TODO(Exception): pass

# Any character that would be interpreted by a POSIX shell forces quoting.
_format_command_search = re.compile("[[\\s $({?*\\`#~';<>&|]").search

def _format_command_escape(s):
  """Wrap *s* in single quotes, escaping embedded single quotes."""
  return "'%s'" % r"'\''".join(s.split("'"))

def format_command(*args):
  """Join *args* into a single shell command line, quoting arguments that
  contain shell metacharacters."""
  return ' '.join(
    _format_command_escape(arg) if _format_command_search(arg) else arg
    for arg in args)

def check_data(f, size, block_size):
  """Yield the MD5 digest of each *block_size* chunk of the next *size*
  bytes of *f*, stopping early on a short read (truncated file)."""
  # XXX: should we cache reads ?
  remaining = size
  while remaining:
    n = min(remaining, block_size)
    data = f.read(n)
    if len(data) != n:
      return
    yield md5(data).digest()
    remaining -= n
    size -= n

class UTIME_OMIT():
  """Sentinel passed to os.utime(ns=...): os.utime splits each ns value
  with divmod(value, 10**9), so this yields (0, 2**30 - 2) — the kernel's
  UTIME_OMIT nanosecond code, leaving that timestamp untouched."""
  def __divmod__(self, divisor):
    assert divisor == 1000000000
    return 0, (1 << 30) - 2
# Replace the class by its single instance.
UTIME_OMIT = UTIME_OMIT()

# Handle to the C library, with errno tracking for error reporting.
libc = CDLL(ctypes_util.find_library('c'), use_errno=True)
def _fallocate():
  # Factory returning a wrapper around Linux fallocate(2), or a no-op on
  # platforms whose libc does not export it.
  # We only use it to reduce fragmentation so this must not be replaced by
  # posix_fallocate (which emulates fallocate if kernel/fs does not support it).
  try:
    libc_fallocate = libc.fallocate
  except AttributeError:
    # No fallocate symbol (non-Linux libc): preallocation is best-effort.
    def _fallocate(fd, offset, length, keep_size=False):
      pass
  else:
    def _fallocate(fd, offset, length, keep_size=False):
      # keep_size=True maps to mode=1 (FALLOC_FL_KEEP_SIZE).
      # NOTE(review): offset/length are passed as c_long, but fallocate
      # takes off_t — presumably fine on 64-bit hosts, may truncate on
      # 32-bit builds; confirm before relying on >2GiB preallocation there.
      if libc_fallocate(fd, bool(keep_size), c_long(offset), c_long(length)):
        e = get_errno()
        # Silently ignore kernels/filesystems without fallocate support.
        if e not in (errno.ENOSYS, errno.EOPNOTSUPP):
          raise OSError(e, os.strerror(e))
  return _fallocate
# Replace the factory by the wrapper it builds.
_fallocate = _fallocate()

def dump_acl(**kw):
  """Serialize the ACL selected by **kw (file=... or filedef=...) into a
  list of (tag<<3 | rwx bits, qualifier) tuples; the qualifier is kept
  only for ACL_USER/ACL_GROUP entries."""
  result = []
  for entry in ACL(**kw):
    tag = entry.tag_type
    perms = entry.permset
    packed = tag << 3 | perms.read << 2 | perms.write << 1 | perms.execute
    qualifier = entry.qualifier if tag in (ACL_USER, ACL_GROUP) else None
    result.append((packed, qualifier))
  return result

def load_acl(acl):
  """Rebuild an ACL object from the representation built by dump_acl()."""
  result = ACL()
  for packed, qualifier in acl:
    entry = Entry(result)
    entry.tag_type = packed >> 3
    if qualifier is not None:
      entry.qualifier = qualifier
    Permset(entry).add(packed & 7)
  return result

class Stat(object):
  """Snapshot of one path's metadata: lstat fields plus, when supported,
  extended attributes and POSIX ACLs.

  'key' identifies the local inode (dev, ino); 'value' is everything that
  must match for the remote copy to be considered up-to-date; the last
  two slots (blocks, blksize) are local-only details.
  """
  # XXX: Consider splitting mode and moving type from value to key,
  #      in order to simplify code.

  NULL_KEY = None, None
  NULL_VALUE = (None,) * 8
  # Slot order matters: [:2] -> key, [2:-2] -> value, [-2:] local-only.
  __slots__ = ('dev', 'ino', 'gid', 'mode', 'mtime_ns', 'rdev', 'size', 'uid',
               'acl', 'attr', 'blocks', 'blksize')

  def __init__(self, path):
    """Stat 'path' without following symlinks, collecting xattr/ACL too."""
    # XXX: see http://bugs.python.org/issue11457
    s = os.lstat(path)
    # Copy the plain lstat fields (all slots except acl/attr/blocks/blksize).
    for k in self.__slots__[:-4]:
      setattr(self, k, getattr(s, 'st_' + k))
    if stat.S_ISDIR(self.mode):
      # Directory sizes are filesystem-dependent: ignore them.
      self.size = None
    else:
      self.blocks = s.st_blocks
      self.blksize = s.st_blksize
    try:
      x = dict((x, os.getxattr(path, x, follow_symlinks=False))
               for x in os.listxattr(path, follow_symlinks=False))
    except (AttributeError, OSError) as e:
      # AttributeError: platform without os.listxattr;
      # ENOTSUP: filesystem without xattr support. Anything else is fatal.
      if isinstance(e, OSError) and e.errno != errno.ENOTSUP:
        raise
      a = d = x = None
    else:
      # ACLs also show up as raw xattrs: drop those and dump them in a
      # portable form via the posix1e API instead.
      a = x.pop('system.posix_acl_access', None) and dump_acl(file=path)
      d = x.pop('system.posix_acl_default', None) and dump_acl(filedef=path)
      x.pop('trusted.SGI_ACL_FILE', None)
      x.pop('trusted.SGI_ACL_DEFAULT', None)
    self.acl = (a or d) and (a, d)
    self.attr = x or None

  def __eq__(self, other):
    # NOTE(review): assumes 'other' is a fully-populated Stat; instances
    # built by load() leave blocks/blksize unset, so comparing those would
    # raise AttributeError — confirm callers never do that.
    for k in self.__slots__:
      if getattr(self, k) != getattr(other, k):
        return False
    return True

  @property
  def key(self):
    """(dev, ino) pair identifying the local inode."""
    return tuple(getattr(self, k) for k in self.__slots__[:2])

  @property
  def value(self):
    """Synchronized metadata tuple (gid, mode, mtime_ns, rdev, size, uid,
    acl, attr)."""
    return tuple(getattr(self, k) for k in self.__slots__[2:-2])

  @property
  def null_value(self):
    """Like 'value' but keeping only the file type (placeholder rows)."""
    return (None, stat.S_IFMT(self.mode)) + (None,) * 6

  @classmethod
  def load(cls, key, value):
    """Rebuild a Stat from previously stored 'key' and 'value' tuples."""
    self = object.__new__(cls)
    for k, v in zip(self.__slots__, key + value):
      setattr(self, k, v)
    return self


class RpcClient(object):
  """Pickle-based RPC proxy: attribute access builds a stub that serializes
  (name, args, kw) to 'stdout'; wait() reads the result back from 'stdin'."""

  def __init__(self, stdin, stdout, map_users=False):
    self.stdin = stdin
    self.stdout = stdout
    self._map_users = map_users

  def wait(self):
    """Return the result of the last call, re-raising remote exceptions."""
    r = loads(read_rpc(self.stdin))
    if isinstance(r, Exception):
      raise r
    return r

  def __getattr__(self, name):
    """Build, cache and return the RPC stub for method 'name'."""
    send = lambda rpc: write_rpc(self.stdout, dumps(rpc))
    # These methods do not log their arguments.
    if name in ('check', 'print', 'reverse'):
      if name == 'check':
        self.sync_meta # send pwd_grp if not done yet
      def rpc(*args, **kw):
        send((name, args, kw))
    else:
      if name == 'sync_meta' and self._map_users:
        # Send our passwd/group database only if remote will need it.
        send(('pwd_grp', (dict((x.pw_name, x.pw_uid) for x in pwd.getpwall()),
                          dict((x.gr_name, x.gr_gid) for x in grp.getgrall())),
                        {}))
        logger.debug('pwd_grp(...)')
        self.wait()
      def rpc(*args, **kw):
        send((name, args, kw))
        # Log the second argument only when it is small (None/int/bytes).
        if len(args) > 1 and (args[1] is None or
          isinstance(args[1], (int, bytes))):
          logger.debug('%s(%r, %r)', name, args[0], args[1])
        else:
          logger.debug('%s(%r)', name, args[0])
    # Cache the stub so __getattr__ runs at most once per method name.
    setattr(self, name, rpc)
    return rpc


class RpcSshClient(RpcClient):
  # NOTE: we use external 'ssh' command instead of 'paramiko' library because
  #       - it's faster
  #       - it seems the easiest way to have '.ssh/config' taken into account

  def __init__(self, host, command, *args, **kw):
    """Spawn "ssh [-p port] host command" and wire its pipes to the RPC
    layer: we read replies from its stdout and send calls to its stdin."""
    host, port = splitport(host)
    argv = ['ssh']
    if port is not None:
      argv.extend(('-p', port))
    argv.extend((host, command))
    self._p = subprocess.Popen(argv, bufsize=1,
                               stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    super(RpcSshClient, self).__init__(self._p.stdout, self._p.stdin.raw,
                                       *args, **kw)


class Local(object):
  """Local (source) side of the synchronization.

  Walks the tree under 'root', keeps the last known remote state in an
  SQLite database, and drives the remote side through 'rpc' so that the
  remote tree converges to the local one.
  """

  # Pickled Stat.NULL_KEY/NULL_VALUE, used as placeholders in the DB.
  NULL_KEY = dumps(Stat.NULL_KEY, 2)
  NULL_VALUE = dumps(Stat.NULL_VALUE, 2)
  prealloc = False

  def __init__(self, root, db, rpc):
    """root: tree to synchronize (bytes); db: SQLite path or ':memory:';
    rpc: RpcClient to the remote side (None to only fill the database)."""
    assert isinstance(root, bytes), root
    self.root = root
    self.rpc = rpc
    # TODO: Add option to map other filesystems ?
    self.dev_map = {os.lstat(root).st_dev: None}
    self.con = sqlite3.connect(db, isolation_level=None)
    # I wish I could get fd of sqlite connection and pass it to os.fstat
    try:
      self.db_key = None if db == ':memory:' else self.stat(db).key
    except KeyError: # db is on another filesystem
      self.db_key = None
    self.con.execute("PRAGMA synchronous = OFF")
    self.con.text_factory = str
    # A row with null metadata means we are renaming/linking/removing a path:
    # 'inode' contains source path and 'path' is destination (null if removing).
    # This special meaning is part of recovering process in case synchronization
    # was interrupted.
    self.con.execute("""CREATE TABLE IF NOT EXISTS fssync (
      path blob PRIMARY KEY,
      inode blob,
      metadata blob,
      checked integer)
    """)
    self.con.execute("CREATE INDEX IF NOT EXISTS _fssync_i1 ON fssync(inode)")

  def __del__(self):
    # Best-effort close of the SQLite connection.
    self.con.close()

  def __call__(self, filter=None, path_list=None, check=False, print0=False,
               prealloc=False):
    """Entry point: synchronize (or check) each path of 'path_list'
    (default: the whole tree).

    filter: source of a Python expression evaluating to the filter function;
    check: compare DB against remote instead of synchronizing;
    print0: print filtered/mismatching paths, NUL-terminated;
    prealloc: preallocate disk space for new remote files.
    """
    print = print0 and (self.rpc.print if type(self.rpc) is RpcClient # XXX
                                       else sys.stdout.buffer.write)
    if filter:
      # 'filter' is the user-supplied --filter expression: eval is
      # intended here, but never pass untrusted input.
      self.set_filter(eval(filter, globals(), {}), print)
    self.prealloc = prealloc
    self.rollback()
    for p in path_list or (b'',):
      if check:
        for path in self.check(p):
          print and print(path + b'\0')
      else:
        self.sync(p)
        self.clean(p)

  def stat(self, path):
    """Stat 'path', mapping its device through dev_map (raises KeyError if
    the path is on a filesystem we do not synchronize)."""
    s = Stat(path)
    s.dev = self.dev_map[s.dev]
    return s

  def rename(self, path, new_path, null_key=False):
    """Apply a rename (or a removal if 'new_path' is falsy) to the DB,
    including all rows for children of 'path'."""
    sql = self.con.execute
    minmax = self._minmax(path)
    sql("begin")
    try:
      if new_path:
        if null_key:
          sql("update fssync set inode=? where path=?", (self.NULL_KEY, path))
        sql("delete from fssync where path=?", (new_path,))
        sql("update fssync set path=? where path=?", (new_path, path))
        sql("update fssync set path=cast(?||substr(path,?) as blob), checked=0"
            " where ?<path and path<?", (new_path, 1+len(path)) + minmax)
      else:
        sql("delete from fssync where path=?", (path,))
        sql("delete from fssync where ?<path and path<?", minmax)
    except:
      self.con.rollback()
      raise
    self.con.commit()

  def backup(self, path, null_key=False):
    """Move remote 'path' aside to a unique '<path>#fssync<i>.bak' name,
    so that its contents can be reused or removed later."""
    i = 0
    while True:
      backup = path + ('#fssync%u.bak' % i).encode()
      if not os.path.lexists(os.path.join(self.root, backup)):
        try:
          self.con.execute("insert into fssync values (?, ?, null, 0)",
                           (backup, path))
          break
        except sqlite3.IntegrityError:
          pass
      i += 1
    self.rpc.rename(path, backup)
    self.rpc.wait()
    self.rename(path, backup, null_key)

  def filter(self, path, stat):
    """Default filter: nothing is ignored (see set_filter)."""
    pass

  def set_filter(self, func, print=None):
    """Install 'func' as path filter; filtered paths may be printed."""
    root = decode(self.root)
    self.filter = lambda path, stat: func(root, decode(path), stat) and (
      print and print(path + b'\0') or 1)

  # Bounds for range queries selecting all children of a path:
  # (path + sep, path + next-char-after-sep).
  _minmax = staticmethod((lambda a, b: lambda path: (path + a, path + b))
                         (os.sep.encode(), chr(ord(os.sep) + 1).encode()))

  def non_empty(self, path, mode):
    """True if 'path' is a directory with known children in the DB."""
    return stat.S_ISDIR(mode) and self.con.execute(
      "select 1 from fssync where ?<path and path<?",
      self._minmax(path)).fetchone()

  def rollback(self):
    """Finish or revert any rename/link/removal that was interrupted."""
    # There should not be more than 1 path to recover.
    sql = self.con.execute
    for path, other, metadata, _ in sql("select * from fssync where"
        " metadata is null or metadata=?", (self.NULL_VALUE,)).fetchall():
      if metadata:
        self.rpc.null_value(path)
        metadata = self.rpc.wait()
        if metadata:
          sql("update fssync set metadata=? where path=?", (metadata, path))
          continue
      else:
        self.rpc.rollback(path, other)
        self.rpc.wait()
      sql("delete from fssync where path=?", (path,))

  def sync(self, path):
    """Synchronize 'path' (recursively for directories) to remote."""
    sql = self.con.execute
    p = os.path.join(self.root, path)
    while 1:
      try:
        s = self.stat(p)
      except (KeyError,           # other filesystem
              FileNotFoundError): # oops, it has just been deleted
        if not path:
          raise
        return
      if path and self.filter(path, s):
        return
      key = s.key
      if key == self.db_key:
        return
      fmt = stat.S_IFMT(s.mode)
      if fmt == stat.S_IFDIR:
        try:
          # it may raise if folder has just been deleted or replaced
          children = os.listdir(p)
        except OSError as e:
          if path:
            if e.errno == errno.ENOTDIR:
              continue
            if e.errno == errno.ENOENT:
              return
          raise
      elif fmt == stat.S_IFLNK:
        try:
          # it may raise if symlink has just been deleted or replaced
          target = os.readlink(p)
        except OSError as e:
          if path:
            if e.errno == errno.EINVAL:
              continue
            if e.errno == errno.ENOENT:
              return
          raise
      break
    r = sql('select * from fssync where path=?', (path,)).fetchall()
    if r:
      (_, inode, metadata, checked), = r
      if checked:
        return
      # Path already exists on remote and needs to be synchronized.
      old_key = loads(inode)
      r = Stat.load(old_key, loads(metadata))
      old_fmt = stat.S_IFMT(r.mode)
      if old_key != key:
        # Local inode has changed.
        if old_key == Stat.NULL_KEY or \
           sql('select path from fssync where inode=? and path!=?',
               (dumps(old_key, 2), path)).fetchone():
          # Remove remote ...
          if self.non_empty(path, r.mode):
            # ... but not directory contents, so delay removal.
            # Even if local inode is a directory, such cleanup is
            # required in case we want to rename from another path.
            self.backup(path, True)
          elif (# ... non-directory so that directory can be created
                fmt == stat.S_IFDIR if fmt != old_fmt
                # and hardlinks of same type.
                else old_key != Stat.NULL_KEY):
            # All other cases are handled by remote, which removes existing
            # inode automatically if its type differs or if we hardlink/rename
            # from another path.
            sql("update fssync set inode=null, metadata=null where path=?",
                (path,))
            self.rpc.remove(path)
            self.rpc.wait()
            sql("delete from fssync where path=?", (path,))
          elif path:
            logger.debug("let remote remove or reuse %r", path)
        else:
          self.backup(path)
        r = None
      elif fmt != stat.S_IFDIR and self.non_empty(path, r.mode):
        # Non-empty directory replaced by non-directory with same inode.
        self.backup(path, True)
      elif fmt != old_fmt:
        # Remote path will be automatically deleted and recreated.
        # If we get interrupted, remote file type will be unknown,
        # and we'll need to ask remote during recovery.
        sql("update fssync set metadata=? where path=?",
            (self.NULL_VALUE, path))
    elif self.rpc is None:
      # Assume remote is already up-to-date.
      r = s
    if not r:
      # Current path does not exist on remote or can't be reused.
      # Let's see if local inode is the result of a rename or hardlink ...
      c = sql('select * from fssync where inode=?', (dumps(key, 2),))
      r = c.fetchone()
      while r:
        backup, inode, metadata, _ = r
        r = Stat.load(loads(inode), loads(metadata))
        if fmt == stat.S_IFMT(r.mode):
          # ... probably (type is the same) so reuse it.
          # If we get interrupted, recovery phase must revert hardlink/rename.
          sql("insert or replace into fssync values (?, ?, null, 0)",
              (path, backup))
          # If target already exists (and is not a non-empty directory),
          # it will be automatically deleted first.
          try:
            if fmt == stat.S_IFDIR:
              self.rpc.rename(backup, path)
              self.rpc.wait()
              self.rename(backup, path)
            else:
              while True:
                self.rpc.link(backup, path)
                try:
                  self.rpc.wait()
                  break
                except FileNotFoundError:
                  logger.warning("missing %r on remote", backup)
                  # File doesn't exist anymore on remote.
                  # Maybe there's another hardlink to try.
                  # (sqlite3 cursors implement the iterator protocol, not a
                  # .next() method: use next() so that exhaustion raises the
                  # StopIteration handled below.)
                  backup = next(c)[0]
              sql("update fssync set inode=?, metadata=? where path=?",
                  (inode, metadata, path))
            break
          except StopIteration:
            # All hardlinks are missing on remote, so create new.
            logger.warning("... create new file for %r", path)
        r = None
      else:
        # If an error happens before the final SQL "update", and path disappears
        # before next sync, then we want to be sure it'll be removed on remote.
        sql("insert or replace into fssync values (?, ?, ?, 0)",
            (path, self.NULL_KEY, dumps(s.null_value, 2)))
    if fmt == stat.S_IFDIR:
      # Process contents before because this may alter permissions or
      # modification time.
      for name in children:
        self.sync(os.path.join(path, name))
    value = s.value
    if not r or value != r.value:
      # The first RPC will automatically remove existing inode if its type
      # differs. It will only fail if it tries to remove a non-empty directory.
      if fmt == stat.S_IFREG:
        try:
          f = open(p, 'rb')
        except (FileNotFoundError, IsADirectoryError):
          if path:
            return
          raise
        try:
          s2 = os.fstat(f.fileno())
          if key != (self.dev_map.get(s2.st_dev, ()), s2.st_ino) \
             or not stat.S_ISREG(s2.st_mode):
            return
          checked = 0
          unchecked = s.size
          if getattr(r, 'size', 0) < unchecked:
            sparse = s.blocks * 512 < unchecked
            if sparse or self.prealloc and BLOCK_SIZE < unchecked:
              self.rpc.truncate(path, unchecked, sparse)
              self.rpc.wait()
          while unchecked > 0:
            part_size = min(PART_SIZE, unchecked)
            self.rpc.check_data(path, checked, part_size, BLOCK_SIZE)
            local_hash_list = list(check_data(f, part_size, BLOCK_SIZE))
            remote_hash_list = self.rpc.wait()
            if not local_hash_list:
              break # local file truncated during check
            diff_list = []
            for i, h in enumerate(local_hash_list):
              if remote_hash_list[i] != h:
                f.seek(checked)
                diff_list.append((checked, f.read(BLOCK_SIZE)))
              checked += BLOCK_SIZE
              unchecked -= BLOCK_SIZE
            if diff_list:
              self.rpc.sync_data(path, diff_list)
              self.rpc.wait()
            f.seek(checked)
        except IndexError: # remote file is smaller
          if diff_list:
            self.rpc.sync_data(path, diff_list)
            self.rpc.wait()
          f.seek(checked)
          while unchecked > 0:
            self.rpc.sync_data(path, ((checked, f.read(PART_SIZE)),))
            checked += PART_SIZE
            unchecked -= PART_SIZE
            self.rpc.wait()
        finally:
          f.close()
      elif fmt == stat.S_IFLNK:
        self.rpc.symlink(path, target)
        self.rpc.wait()
      self.rpc.sync_meta(path, s) # creates the file/dir
      self.rpc.wait()             # (and parents) automatically
    # Do not use 'update' because we didn't insert anything if self.rpc is None.
    sql("insert or replace into fssync values (?, ?, ?, 1)",
        (path, dumps(key, 2), dumps(value, 2)))

  def clean(self, path):
    """Remove from remote (and from the DB) every known path under 'path'
    that was not seen during the last sync, then reset 'checked' flags."""
    sql = self.con.execute
    if path:
      args = (path,) + self._minmax(path)
      path = "(path=? or ?<path and path<?) and"
    else:
      args = ()
      path = ''
    path_list = [x for x, in sql(
      "select path from fssync where %s checked=0" % path, args).fetchall()]
    if path_list:
      path_list.sort(reverse=True)
      self.rpc.removemany(path_list)
      self.rpc.wait()
      sql("delete from fssync where %s checked=0" % path, args)
    sql("update fssync set checked=0 where %s 1" % path, args)

  def check(self, path):
    """Yield paths under 'path' whose DB metadata differs from remote,
    batching 100 rows per RPC."""
    if path:
      # 'path' is bytes: build the range bounds with _minmax (as clean()
      # does) instead of concatenating bytes with the str os.sep.
      args = (path,) + self._minmax(path)
      path = "where path=? or ?<path and path<?"
    else:
      args = ()
      path = ""
    fetchone = self.con.execute("select path, metadata from fssync "
                                + path + " order by path", args).fetchone
    while True:
      row_list = []
      try:
        for x in range(100):
          path, metadata = fetchone()
          row_list.append((path, loads(metadata)))
      except TypeError:
        if not row_list:
          break
      self.rpc.check(row_list)
      for path in self.rpc.wait():
        yield path

#  def update(self):
#    path = ''
#    while 1:
#      for path, inode, metadata, _ in self.con.execute(
#         'select * from fssync where ?<path and checked!=0 order by path',
#         (path,)).fetchall():
#        r = Stat.load(loads(inode), loads(metadata))
#        if self.filter(path, r):
#          self.clean(path)
#          break
#      else:
#        break


class Remote(object):
  """Remote (destination) side: executes RPCs received on stdin and
  applies them to the tree under 'root'."""

  _open_args = None
  # Identity uid/gid mapping, until pwd_grp() installs a real one.
  _pwd = _grp = staticmethod(lambda id: id)

  def __init__(self, root):
    assert isinstance(root, bytes), root
    self.root = root

  def _open(self, path, mode='rb'):
    """Return a file object for (path, mode), reusing the last one open."""
    if self._open_args != (path, mode):
      self._close()
      self._open_file = open(path, mode)
      self._open_args = path, mode
    return self._open_file

  def _close(self):
    """Close the cached file, if any (falls back to class attributes)."""
    if self._open_args:
      self._open_file.close()
      del self._open_args, self._open_file

  def __call__(self, stdin, stdout):
    """RPC loop: read pickled (method, args, kw) requests and reply with
    pickled results (or exceptions). 'reverse' swaps roles and exits."""
    while 1:
      try:
        method, args, kw = loads(read_rpc(stdin))
      except EOFError:
        break
      if method == 'reverse':
        Local(self.root, args[0], RpcClient(stdin, stdout, args[1]))(
          *args[2:], **kw)
        break
      try:
        if method == 'print':
          sys.stdout.buffer.write(*args)
          continue
        if isinstance(args[0], bytes):
          # A bytes first argument is a path relative to our root.
          args = (os.path.join(self.root, args[0]),) + args[1:]
        result = getattr(self, method)(*args, **kw)
      except Exception as e:
        logger.exception('%s(%r)', method, args[0])
        result = e
      write_rpc(stdout, dumps(result))

  def pwd_grp(self, p, g):
    """Install uid/gid mappings from the local host's name->id dicts."""
    p = dict((p.get(x.pw_name), x.pw_uid) for x in pwd.getpwall())
    g = dict((g.get(x.gr_name), x.gr_gid) for x in grp.getgrall())
    p.pop(None, None); self._pwd = p.__getitem__
    g.pop(None, None); self._grp = g.__getitem__

  def map_acl(self, acl):
    """Map the uid/gid qualifiers of a dumped ACL to our own ids."""
    return acl and [(t, None if q is None else
        (self._pwd if t >> 3 == ACL_USER else self._grp)(q))
      for t, q in acl]

  @staticmethod
  @contextmanager
  def _preserve(path_list, parent=False):
    """Temporarily make paths writable, restoring mode & mtime on exit."""
    # XXX: This is the only place where fssync should not be terminated,
    #      or directories may end up with wrong permissions/timestamps.
    if parent:
      path_list = frozenset(map(os.path.dirname, path_list))
    parent_list = []
    try:
      for path in path_list:
        try:
          s = os.lstat(path)
        except FileNotFoundError:
          continue
        mode = None if os.access(path, 2) else stat.S_IMODE(s.st_mode)
        parent_list.append((path, s.st_mtime_ns, mode))
        if mode is not None:
          os.chmod(path, 0o700)
      yield
    finally:
      for path, mtime_ns, mode in parent_list:
        if mode is not None:
          os.chmod(path, mode)
        os.utime(path, ns=(UTIME_OMIT, mtime_ns), follow_symlinks=False)

  def removemany(self, path_list):
    """Remove several paths, preserving metadata of surviving parents."""
    with self._preserve(os.path.join(self.root, path)
        for path in set(map(os.path.dirname, path_list)).difference(path_list)):
      for path in path_list:
        self.remove(os.path.join(self.root, path))

  @staticmethod
  def isdir(path):
    """Whether 'path' is a directory; None if it does not exist."""
    try:
      return stat.S_ISDIR(os.lstat(path).st_mode)
    except FileNotFoundError:
      pass

  def remove(self, path):
    """Remove a file or (empty) directory, ignoring missing paths."""
    try:
      os.remove(path)
    except IsADirectoryError:
      os.rmdir(path)
    except PermissionError:
      # Parent directory not writable: force it and retry.
      os.chmod(os.path.dirname(path), 0o700)
      (os.rmdir if self.isdir(path) else os.remove)(path)
    except FileNotFoundError:
      pass

  def link(self, path, new_path):
    """Hardlink 'path' to root-relative 'new_path' (no-op if already the
    same inode; existing target is removed first)."""
    new_path = os.path.join(self.root, new_path)
    if os.path.lexists(new_path):
      if os.path.samestat(os.lstat(path), os.lstat(new_path)):
        return
      self.remove(new_path)
    else:
      self._makeparents(new_path)
    os.link(path, new_path)

  def rename(self, path, new_path):
    """Rename 'path' to root-relative 'new_path', preserving the source
    parent's metadata and replacing any existing target."""
    new_path = os.path.join(self.root, new_path)
    with self._preserve((path,), True):
      if os.path.lexists(new_path):
        self.remove(new_path)
      else:
        self._makeparents(new_path)
      os.rename(path, new_path)

  def rollback(self, path, other):
    """Recovery helper: undo an interrupted rename/link of 'path'."""
    if other is not None and self.isdir(path):
      other = os.path.join(self.root, other)
      with self._preserve((other,), True):
        # If src & dst are the same inode, this is a no-op.
        os.rename(path, other)
    self.remove(path)

  def null_value(self, path):
    """Recovery helper: return the type-only metadata of 'path', or None
    if it does not exist."""
    try:
      return Stat(path).null_value
    except FileNotFoundError:
      pass

  @staticmethod
  def _makeparents(path):
    """Make sure the parent directory of 'path' exists and is writable."""
    d = os.path.dirname(path)
    if os.path.lexists(d):
      if not os.access(d, 2):
        os.chmod(d, 0o700)
    else:
      os.makedirs(d, 0o700)

  def sync_meta(self, path, l):
    """Create 'path' if needed and apply the local metadata 'l' (a Stat):
    type, size, ACLs, mode, owner, mtime and extended attributes."""
    x = stat.S_IFMT(l.mode)
    mode = stat.S_IMODE(l.mode)
    size = l.size
    while 1:
      try:
        s = Stat(path)
      except FileNotFoundError:
        # Symlinks and non-empty regular files must have been created by
        # the data RPCs already.
        if x == stat.S_IFLNK or x == stat.S_IFREG and size:
          raise
      else:
        if x == stat.S_IFMT(s.mode):
          break
        self.remove(path)
      self._makeparents(path)
      if x == stat.S_IFDIR:
        os.mkdir(path, mode)
      elif x == stat.S_IFREG:
        os.close(os.open(path, os.O_CREAT, mode))
      else:
        os.mknod(path, l.mode, l.rdev)
    if x == stat.S_IFREG:
      self._close()
      if size < s.size or size <= s.blocks * 512 - s.blksize:
        # l.size > s.size implies that local file was truncated during sync:
        # we don't want the backup to become sparse and next run will fix this
        # We also check number of blocks in case we allocated more disk space
        # than necessary.
        try:
          os.truncate(path, size)
        except PermissionError:
          s.mode |= 0o200
          os.chmod(path, s.mode)
          os.truncate(path, size)
        if s.size <= size:
          # Do it twice if we don't reduce file size,
          # to force FS to free extra space.
          os.truncate(path, size)
    acl = self.map_acl(l.acl and l.acl[0])
    if acl != (s.acl and s.acl[0]):
      (load_acl(acl) if acl else ACL(mode=mode)).applyto(path)
    elif mode != stat.S_IMODE(s.mode):
      os.chmod(path, mode)
    x = self._pwd(l.uid), self._grp(l.gid)
    if x != (s.uid, s.gid):
      os.lchown(path, *x)
    x = l.mtime_ns
    if x != s.mtime_ns:
      os.utime(path, ns=(UTIME_OMIT, x), follow_symlinks=False)
    acl = self.map_acl(l.acl and l.acl[1])
    if acl != (s.acl and s.acl[1]):
      (load_acl(acl).applyto if acl else delete_default)(path)
    x = l.attr
    if x != s.attr:
      for attr in set(s.attr or ()).difference(x or ()):
        os.removexattr(path, attr, follow_symlinks=False)
      # 'x' maps xattr names to values: iterate its items, not its keys.
      for attr, value in (x or {}).items():
        os.setxattr(path, attr, value, follow_symlinks=False)

  def _makereg(self, path):
    """Return True if 'path' is a regular file; otherwise remove it (or
    create its parents) and return False."""
    try:
      if stat.S_ISREG(os.lstat(path).st_mode):
        return True
    except FileNotFoundError:
      self._makeparents(path)
    else:
      self.remove(path)
    return False

  def check_data(self, path, start, size, block_size):
    """Return MD5 digests of 'size' bytes of 'path' starting at 'start'
    (creating an empty file and returning () if it did not exist)."""
    if start or self._makereg(path):
      f = self._open(path)
      try:
        f.seek(start)
        return tuple(check_data(f, size, block_size))
      except:
        self._close()
        raise
    os.close(os.open(path, os.O_CREAT|os.O_EXCL, 0o600))
    return ()

  def sync_data(self, path, diff_list):
    """Write each (offset, data) pair of 'diff_list' into 'path'."""
    try:
      f = self._open(path, 'r+b')
    except PermissionError:
      os.chmod(path, 0o600)
      f = self._open(path, 'r+b')
    try:
      for offset, data in diff_list:
        f.seek(offset)
        f.write(data)
    except:
      self._close()
      raise

  def symlink(self, path, target):
    """(Re)create 'path' as a symlink to 'target'."""
    self._makeparents(path)
    self.remove(path)
    os.symlink(target, path)

  def truncate(self, path, size, sparse):
    """Resize 'path' to 'size': a plain truncate for sparse files,
    otherwise a real preallocation to reduce fragmentation."""
    self._makereg(path)
    try:
      fd = os.open(path, os.O_WRONLY|os.O_CREAT, 0o600)
    except PermissionError:
      os.chmod(path, 0o600)
      fd = os.open(path, os.O_WRONLY)
    try:
      if sparse:
        os.ftruncate(fd, size)
      else:
        _fallocate(fd, 0, size, True)
    finally:
      os.close(fd)

  def check(self, item_list):
    """Return the paths of 'item_list' ((path, metadata) pairs) whose
    recorded metadata differs from the filesystem."""
    path_list = []
    for path, metadata in item_list:
      l = Stat.load(Stat.NULL_KEY, metadata)
      l.uid = self._pwd(l.uid)
      l.gid = self._grp(l.gid)
      try:
        s = Stat(os.path.join(self.root, path))
        metadata = s.value
      except FileNotFoundError:
        metadata = None
      if l.value != metadata:
        logger.warning("%s: DB %r differs from FS %r", path, l.value, metadata)
        path_list.append(path)
    return path_list

class SysLogHandler(logging.handlers.SysLogHandler):
  """Log handler that forwards records through the 'syslog' module (one
  syslog() call per formatted line) instead of a syslog socket."""

  # Bypass the parent's socket setup/teardown entirely.
  __init__ = logging.Handler.__init__
  close = logging.Handler.close

  def emit(self, record):
    """Format *record* and send each line to syslog at a mapped priority."""
    try:
      level = self.mapPriority(record.levelname)
      priority = self.priority_names[level]
      for line in self.format(record).splitlines():
        syslog.syslog(priority, line)
    except Exception:
      self.handleError(record)


def main():
  """Command-line entry point: parse arguments and run fssync in remote,
  reverse or local mode."""
  parser = argparse.ArgumentParser(
    description="File system synchronization tool")
  _ = parser.add_argument
  _('--remote', action='store_true',
              help="Used on the remote host to process RPC from local host.")
  _('-l', '--logfile',
              help="Output logging messages to specified file. If unset,"
                   " they are logged to stderr (local) or to syslog (remote)."
                   " Passing /dev/null disables logging.")
  _('-r', '--root', required=True, type=encode,
              help="Root path of dirs/files to synchronize.")
  _('-v', '--verbose', action='store_true',
              help="Increase verbosity to DEBUG level. Default level is INFO.")
  # Options that only make sense on the local side: they are checked
  # against --remote below.
  local_options = (
    _('--reverse', action='store_true',
              help="Synchronize data from remote(-R, -L, -d, -f)"
                   " to local(-r, -l), instead of doing it"
                   " from local(-r, -l, -d, -f) to remote(-R, -L)."
                   " This option is useful when the SSH connection can be"
                   " established in only 1 way."
                   " It is not an option for 2-way synchronization:"
                   " data must always be synchronized in the same direction."
                   " This options conflicts with --print0 and host=-"),
    _('--print0', action='store_true',
              help="Print filtered paths on the standard output, followed by a"
                   " null character. When used with --check, this prints"
                   " paths that don't match database. You can pipe result to"
                   " \"tr -s '\\0' '\\n'\" if you want newline separators."),
    _('-a', '--allocate', action='store_true',
              help="Preallocate disk space on destination file system for"
                   " non-sparse files. This reduces disk fragmentation but"
                   " prevent files from being compressed on Btrfs."),
    _('-c', '--check', action='store_true',
              help="Check database matches remote host, instead of"
                   " synchronizing. This does not check file contents,"
                   " symlink target and hardlinks."
                   " -a/-f have no effect with this option."),
    _('-d', '--db',
              help="File path to database that maintains"
                   " the state of files on the remote side."),
    _('-f', '--filter',
              help="Python expression that evaluates to a function"
                   " which takes 3 arguments (root, p, s) and returns"
                   " True if 'root/p' must be ignored. 's' holds some"
                   " stats about 'root/p': see 'Stat' class."),
    _('-m', '--map-users', action='store_true',
              help="Map uid/gid, including in ACLs, so that names don't change"
                   " between local and remote hosts. This requires remote to"
                   " define all users/groups that may appear locally."),
    _('-L', '--remote-logfile',
              help="Set '-l' option for remote end."),
    _('-R', '--remote-root',
              help="Set '-r' option for remote end."),
    _('-X', '--remote-executable',
              help="Path to fssync executable on remote side."
                   " If unset, it is guessed from the 0-th argument."),
    _('host', nargs='?',
              help='SSH to connect (syntax: [user@]host[:port]).'
                   " Can be '-' to initialize database, provided you"
                   " synchronized everything by other means (like rsync)."))
  _('path', nargs='*', type=encode,
              help="Synchronize only these entries"
                   " (paths must be relative to --root).")
  args = parser.parse_args()

  if not args.root:
    parser.error('empty -r/--root argument')

  # Logging setup: file, syslog (remote) or stderr.
  if args.logfile == os.devnull:
    logging.disable(logging.CRITICAL)
  else:
    format = '%(asctime)s %(levelname)s %(message)s'
    if args.logfile:
      handler = logging.FileHandler(args.logfile)
    elif args.remote:
      handler = SysLogHandler()
      format = '%(message)s'
    else:
      handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(format))
    root = logging.getLogger()
    root.setLevel(logging.DEBUG if args.verbose else logging.INFO)
    root.addHandler(handler)

  # Exit cleanly on hangup/termination so 'finally' clauses run.
  signal.signal(signal.SIGHUP, lambda *args: sys.exit(-1))
  signal.signal(signal.SIGTERM, lambda *args: sys.exit())

  if args.remote:
    # Local-only options are meaningless with --remote: reject them.
    for a in local_options:
      v = getattr(args, a.dest)
      if v is a.const if a.nargs == 0 else v is not None:
        parser.error('conflicting options --remote and ' +
          ('/'.join(a.option_strings) or a.metavar))

    Remote(args.root)(sys.stdin.buffer, sys.stdout.buffer.raw)

  else:
    if args.remote_executable is None:
      args.remote_executable = os.path.realpath(sys.argv[0])
    if args.remote_root is None:
      # args.root was already encoded to bytes by argparse (type=encode):
      # decode it back to str for use on the ssh command line.
      args.remote_root = decode(args.root)
    for a in local_options:
      if a.nargs != 0 and not (getattr(args, a.dest) or a.dest in (
          'filter', 'remote_logfile')):
        parser.error('empty or missing %s argument' %
          ('/'.join(a.option_strings) or a.dest))

    if args.host == '-':
      rpc = None
    else:
      opt = ['--remote', '-r', args.remote_root]
      if args.remote_logfile:
        opt += '-l', args.remote_logfile
      if args.verbose:
        opt.append('-v')
      rpc = RpcSshClient(args.host,
        format_command(args.remote_executable, *opt),
        args.map_users and not args.reverse)
    action = args.filter, args.path, args.check, args.print0, args.allocate
    if args.reverse:
      if rpc is None:
        parser.error("conflicting options --reverse and HOST=-")
      rpc.reverse(args.db, args.map_users, *action)
      Remote(args.root)(rpc.stdin, rpc.stdout)
    else:
      Local(args.root, args.db, rpc)(*action)

if __name__ == '__main__': # script entry point
  sys.exit(main())
