#!/usr/bin/python
# (c) 2014, Kevin Carter <kevin.carter@rackspace.com>
#
# Copyright 2014, Rackspace US, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import os
import stat
import sys

import memcache
try:
    from Crypto.Cipher import AES
    from Crypto import Random

    ENCRYPT_IMPORT = True
except ImportError:
    ENCRYPT_IMPORT = False

# import module snippets
from ansible.module_utils.basic import *

# YAML-formatted documentation for this module: its options, usage notes,
# and requirements. The text is kept in a module-level string.
DOCUMENTATION = """
---
module: memcached
version_added: "1.6.6"
short_description:
   - Add, remove, and get items from memcached
description:
   - Add, remove, and get items from memcached
options:
    name:
        description:
            - Memcached key name
        required: true
    content:
        description:
            - Add content to memcached. Only used when state is 'present'.
        required: false
    file_path:
        description:
            - This can be used with state 'present' and 'retrieve'. When set
              with state 'present' the contents of a file will be used, when
              set with state 'retrieve' the contents of the memcached key will
              be written to a file.
        required: false
    state:
        description:
            - ['absent', 'present', 'retrieve']
        required: true
    server:
        description:
            - server IP address and port. This can be a comma separated list of
              servers to connect to.
        required: true
    encrypt_string:
        description:
            - Encrypt/Decrypt a memcached object using a provided value.
        required: false
    dir_mode:
        description:
            - If a directory is created when using the ``file_path`` argument
              the directory will be created with a set mode.
        default: '0755'
        required: false
    file_mode:
        description:
            - If a file is created when using the ``file_path`` argument
              the file will be created with a set mode.
        default: '0644'
        required: false
    expires:
        description:
            - Seconds until an item is expired from memcached.
        default: 300
        required: false
notes:
    - The "absent" state will remove an item from memcached.
    - The "present" state will place an item from a string or a file into
      memcached.
    - The "retrieve" state will get an item from memcached and return it as a
      string. If a ``file_path`` is set this module will also write the value
      to a file.
    - All items added into memcached are base64 encoded.
    - All items retrieved will attempt base64 decode and return the string
      value if not applicable.
    - Items retrieve from memcached are returned within a "value" key unless
      a ``file_path`` is specified which would then write the contents of the
      memcached key to a file.
    - The ``file_path`` and ``content`` fields are mutually exclusive.
    - If you'd like to encrypt items in memcached PyCrypto is a required.
requirements:
    - "python-memcached"
optional_requirements:
    - "pycrypto"
author: Kevin Carter
"""

# Example playbook tasks for this module. ``register`` is a task-level
# keyword, so it is aligned with the module name rather than nested under the
# module's arguments (where it would be passed to the module as a parameter).
EXAMPLES = """
# Add an item into memcached.
- memcached:
    name: "key_name"
    content: "Super awesome value"
    state: "present"
    server: "localhost:11211"

# Read the contents of a memcached key, returned as "memcached_key.value".
- memcached:
    name: "key_name"
    state: "retrieve"
    server: "localhost:11211"
  register: memcached_key

# Add the contents of a file into memcached.
- memcached:
    name: "key_name"
    file_path: "/home/user_name/file.txt"
    state: "present"
    server: "localhost:11211"

# Write the contents of a memcached key to a file and is returned as
# "memcached_key.value".
- memcached:
    name: "key_name"
    file_path: "/home/user_name/file.txt"
    state: "retrieve"
    server: "localhost:11211"
  register: memcached_key

# Delete an item from memcached.
- memcached:
    name: "key_name"
    state: "absent"
    server: "localhost:11211"
"""

# 256 KiB; handed to ``memcache.Client`` as ``server_max_value_length`` when
# the client is created.
SERVER_MAX_VALUE_LENGTH = 1024 * 256

# Number of "<name>.<index>" chunk keys that are read, written, or deleted
# for a single named item.
MAX_MEMCACHED_CHUNKS = 256


class AESCipher(object):
    """Encrypt and decrypt strings using AES in CBC mode.

    The key is normalised to exactly 32 characters: longer keys are
    truncated and shorter keys are padded. Messages are padded to a multiple
    of ``self.bs`` (32) before encryption and the random IV is prepended to
    the cipher text before the result is base64 encoded.

    Solution derived from "http://stackoverflow.com/a/21928790"
    """
    def __init__(self, key):
        """Prepare the cipher key.

        :param key: ``str``  Secret used to derive the 32 character AES key.
        :raises ImportError: when PyCrypto could not be imported.
        """
        if ENCRYPT_IMPORT is False:
            raise ImportError(
                'PyCrypto failed to be imported. Encryption is not supported'
                ' on this system until PyCrypto is installed.'
            )

        self.bs = 32
        if len(key) >= 32:
            self.key = key[:32]
        else:
            self.key = self._pad(key)

    def encrypt(self, raw):
        """Encrypt raw message.

        :param raw: ``str``
        :returns: ``str``  Base64 encoded string of the IV followed by the
                           cipher text.
        """
        raw = self._pad(raw)
        # A fresh random IV is generated for every message.
        iv = Random.new().read(AES.block_size)
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return base64.b64encode(iv + cipher.encrypt(raw))

    def decrypt(self, enc):
        """Decrypt an encrypted message.

        :param enc: ``str``  Base64 encoded string produced by ``encrypt``.
        :returns: ``str``  The message with its padding removed.
        """
        enc = base64.b64decode(enc)
        # The IV occupies the first block of the decoded payload.
        iv = enc[:AES.block_size]
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return self._unpad(cipher.decrypt(enc[AES.block_size:]))

    def _pad(self, string):
        """Pad a string to a multiple of ``self.bs`` (PKCS#7 style).

        Each padding character is the character whose ordinal equals the
        number of padding characters added. A string that is already aligned
        receives a full extra block so the padding can always be removed.

        :param string: ``str``
        :returns: ``str``
        """
        pad_length = self.bs - len(string) % self.bs
        return string + pad_length * chr(pad_length)

    @staticmethod
    def _unpad(string):
        """Remove the padding added by ``_pad``.

        The ordinal of the final character is the number of characters to
        strip. An empty string is returned unchanged instead of raising,
        which lets callers treat "nothing decrypted" as a falsy result.

        :param string: ``str``
        :returns: ``str``
        """
        if not string:
            return string

        pad_length = ord(string[-1:])
        return string[:-pad_length]


class Memcached(object):
    """Manage objects within memcached."""
    def __init__(self, module):
        """Bind the Ansible module and reset the per-run state.

        :param module: Module object providing ``params``, ``exit_json`` and
                       ``fail_json``.
        """
        # The memcache client is created later, by ``router``.
        self.mc = None
        # Set to True by any action that reports a modification.
        self.state_change = False
        self.module = module

    def router(self):
        """Dispatch to the method implementing the requested ``state``.

        The ``state`` parameter must name one of the public actions
        (``absent``, ``present`` or ``retrieve``). The memcache client is
        created from the comma separated ``server`` parameter, the action is
        run, and its facts are returned through ``exit_json``. Any exception
        raised while connecting or running the action is reported through
        ``_failure``.
        """

        state = self.module.params['state']
        # Only the public actions may be dispatched. Without this check any
        # attribute name (for example "router" or "_failure") would be
        # looked up and called.
        if state not in ('absent', 'present', 'retrieve'):
            self._failure(
                error='Unsupported state [ %s ]' % state,
                rc=1,
                msg='The state must be one of absent, present, or retrieve.'
            )
            return

        try:
            self.mc = memcache.Client(
                self.module.params['server'].split(','),
                server_max_value_length=SERVER_MAX_VALUE_LENGTH,
                debug=0
            )
            facts = getattr(self, state)()
        except Exception as exp:
            self._failure(error=str(exp), rc=1, msg='general exception')
        else:
            self.module.exit_json(
                changed=self.state_change, **facts
            )
        finally:
            # Close the connections on both the success and failure paths.
            # NOTE(review): exit_json/fail_json presumably end the process by
            # raising SystemExit, which still runs this clause - confirm.
            if self.mc is not None:
                self.mc.disconnect_all()

    def _failure(self, error, rc, msg):
        """Report a failure through the module's ``fail_json``.

        :param error: ``str``  Error that occurred; reported as ``err``.
        :param rc: ``int``     Return code to report.
        :param msg: ``str``    Message to report.
        """

        payload = {'msg': msg, 'rc': rc, 'err': error}
        self.module.fail_json(**payload)

    def absent(self):
        """Remove a key from memcached.

        Every "<name>.<index>" chunk key that may belong to the item is
        deleted. A change is only reported when at least one chunk existed
        beforehand, so running the task against an already absent key is a
        no-op. If any chunk is still present after the delete a failure is
        reported.

        :return: ``dict``
        """

        key_name = self.module.params['name']
        get_keys = [
            '%s.%s' % (key_name, i) for i in range(MAX_MEMCACHED_CHUNKS)
        ]

        if not self.mc.get_multi(get_keys):
            # Nothing stored under this name; already in the desired state.
            return {'absent': True, 'key': key_name}

        self.mc.delete_multi(get_keys)
        # Read the keys back to confirm the delete took effect.
        if not self.mc.get_multi(get_keys):
            self.state_change = True
            return {'absent': True, 'key': key_name}
        else:
            self._failure(
                error='Memcache key not deleted',
                rc=1,
                msg='Failed to remove an item from memcached please check your'
                    ' memcached server for issues. If you are load balancing'
                    ' memcached, attempt to connect to a single node.'
            )

    @staticmethod
    def _decode_value(value):
        """Return a ``str`` from a base64 decoded value.

        If the content is not a base64 ``str`` the raw value will be returned.

        :param value: ``str``
        :return:
        """

        try:
            b64_value = base64.decodestring(value)
        except Exception:
            return value
        else:
            return b64_value

    def _encode_value(self, value):
        """Return a base64 encoded value.

        If the value can't be base64 encoded an excption will be raised.

        :param value: ``str``
        :return: ``str``
        """

        try:
            b64_value = base64.encodestring(value)
        except Exception as exp:
            self._failure(
                error=str(exp),
                rc=1,
                msg='The value provided can not be Base64 encoded.'
            )
        else:
            return b64_value

    def _file_read(self, full_path, pass_on_error=False):
        """Read the contents of a file.

        This will read the contents of a file. If the ``full_path`` does not
        exist an exception will be raised.

        :param full_path: ``str``
        :return: ``str``
        """

        try:
            with open(full_path, 'rb') as f:
                o_value = f.read()
        except IOError as exp:
            if pass_on_error is False:
                self._failure(
                    error=str(exp),
                    rc=1,
                    msg="The file you've specified does not exist. Please"
                        " check your full path @ [ %s ]." % full_path
                )
            else:
                return None
        else:
            return o_value

    def _chown(self, path, mode_type):
        """Chown a file or directory based on a given mode type.

        If the file is modified the state will be changed.

        :param path: ``str``
        :param mode_type: ``str``
        """
        mode = self.module.params.get(mode_type)
        # Ensure that the mode type is a string.
        mode = str(mode)
        _mode = oct(stat.S_IMODE(os.stat(path).st_mode))
        if _mode != mode or _mode[1:] != mode:
            os.chmod(path, int(mode, 8))
            self.state_change = True

    def _file_write(self, full_path, value):
        """Write the contents of ``value`` to the ``full_path``.

        This will return True upon success and will raise an exception upon
        failure.

        :param full_path: ``str``
        :param value: ``str``
        :return: ``bol``
        """

        try:
            # Ensure that the directory exists
            dir_path = os.path.dirname(full_path)
            try:
                os.makedirs(dir_path)
            except OSError as exp:
                if exp.errno == errno.EEXIST and os.path.isdir(dir_path):
                    pass
                else:
                    self._failure(
                        error=str(exp),
                        rc=1,
                        msg="The directory [ %s ] does not exist and couldn't"
                            " be created. Please check the path and that you"
                            " have permission to write the file."
                    )

            # Ensure proper directory permissions
            self._chown(path=dir_path, mode_type='dir_mode')

            # Write contents of a cached key to a file.
            with open(full_path, 'wb') as f:
                if isinstance(value, list):
                    f.writelines(value)
                else:
                    f.write(value)

            # Ensure proper file permissions
            self._chown(path=full_path, mode_type='file_mode')

        except IOError as exp:
            self._failure(
                error=str(exp),
                rc=1,
                msg="There was an issue while attempting to write to the"
                    " file [ %s ]. Please check your full path and"
                    " permissions." % full_path
            )
        else:
            return True

    def retrieve(self):
        """Return a value from memcached.

        All "<name>.<index>" chunks are fetched and joined in index order.
        The returned ``value`` is the stored representation exactly as read
        from memcached. If ``file_path`` is specified the stored value is
        decrypted (when ``encrypt_string`` is set) or base64 decoded and
        written to that file whenever it differs from the file's current
        contents. If no chunk is found a failure is reported.

        :returns: ``dict``
        """

        key_name = self.module.params['name']
        get_keys = [
            '%s.%s' % (key_name, i) for i in range(MAX_MEMCACHED_CHUNKS)
        ]
        multi_value = self.mc.get_multi(get_keys)
        if multi_value:
            # Walk the keys in index order; the order of the returned
            # mapping is not guaranteed to match the order the chunks were
            # written in.
            value = ''.join(
                [
                    multi_value[key] for key in get_keys
                    if multi_value.get(key) is not None
                ]
            )
            # Get the file path if specified.
            file_path = self.module.params.get('file_path')
            if file_path is not None:
                full_path = os.path.abspath(os.path.expanduser(file_path))

                # Decode cached value
                encrypt_string = self.module.params.get('encrypt_string')
                if encrypt_string:
                    _d_value = AESCipher(key=encrypt_string)
                    d_value = _d_value.decrypt(enc=value)
                    # Fall back to plain base64 when decryption yields
                    # nothing.
                    if not d_value:
                        d_value = self._decode_value(value=value)
                else:
                    d_value = self._decode_value(value=value)

                # A missing or unreadable file reads as None, which never
                # equals the decoded value and so forces a write.
                o_value = self._file_read(
                    full_path=full_path, pass_on_error=True
                )

                # compare old value to new value and write if different
                if o_value != d_value:
                    self.state_change = True
                    self._file_write(full_path=full_path, value=d_value)

                return {
                    'present': True,
                    'key': key_name,
                    'value': value,
                    'file_path': full_path
                }
            else:
                return {
                    'present': True,
                    'key': key_name,
                    'value': value
                }
        else:
            self._failure(
                error='Memcache key not found',
                rc=1,
                msg='The key you specified was not found within memcached. '
                    'If you are load balancing memcached, attempt to connect'
                    ' to a single node.'
            )

    def present(self):
        """Create and or update a key within Memcached.

        The content comes from the file at ``file_path`` when that parameter
        is set, otherwise from ``content``. It is AES encrypted when
        ``encrypt_string`` is set and base64 encoded otherwise, then stored
        as "<name>.<index>" chunks of at most 128 KiB each. Chunk keys left
        over from a previously stored, longer value are deleted so that a
        later retrieve does not append stale data.

        :returns: ``dict``
        """

        file_path = self.module.params.get('file_path')
        if file_path is not None:
            full_path = os.path.abspath(os.path.expanduser(file_path))
            # Read the contents of a file into memcached.
            o_value = self._file_read(full_path=full_path)
        else:
            o_value = self.module.params.get('content')

        if o_value is None:
            self._failure(
                error='No content provided',
                rc=1,
                msg='Either content or file_path must be set when state is'
                    ' present.'
            )
            return None

        # Encode cached value
        encrypt_string = self.module.params.get('encrypt_string')
        if encrypt_string:
            _d_value = AESCipher(key=encrypt_string)
            d_value = _d_value.encrypt(raw=o_value)
        else:
            d_value = self._encode_value(value=o_value)

        # Size the chunks from the payload length itself so the number of
        # chunks can never exceed the keys that retrieve/absent look at.
        compare = 1024 * 128
        if len(d_value) > compare * MAX_MEMCACHED_CHUNKS:
            self._failure(
                error='Memcache content too large',
                rc=1,
                msg='The content that you are attempting to cache is larger'
                    ' than [ %s ] megabytes.'
                    % (compare * MAX_MEMCACHED_CHUNKS // 1024 // 1024)
            )
            return None

        key_name = self.module.params['name']
        split_d_value = {}
        for count, i in enumerate(range(0, len(d_value), compare)):
            split_d_value['%s.%s' % (key_name, count)] = d_value[i:i + compare]

        # set_multi returns the keys that could not be stored.
        value = self.mc.set_multi(
            mapping=split_d_value,
            time=self.module.params['expires'],
            min_compress_len=2048
        )

        if not value:
            # Remove chunk keys beyond the ones just written.
            stale_keys = [
                '%s.%s' % (key_name, i)
                for i in range(len(split_d_value), MAX_MEMCACHED_CHUNKS)
            ]
            if stale_keys:
                self.mc.delete_multi(stale_keys)

            self.state_change = True
            return {
                'present': True,
                'key': key_name
            }
        else:
            self._failure(
                error='Memcache content not created',
                rc=1,
                msg='The content you attempted to place within memcached'
                    ' was not created. If you are load balancing'
                    ' memcached, attempt to connect to a single node.'
                    ' Returned a value of unstored keys [ %s ] - Original'
                    ' Connection [ %s ]'
                    % (value, [i.__dict__ for i in self.mc.servers])
            )


def main():
    """Main ansible run method.

    Declares the module's argument specification, builds the module object,
    and hands control to ``Memcached.router``.
    """
    module = AnsibleModule(
        argument_spec=dict(
            name=dict(
                type='str',
                required=True
            ),
            content=dict(
                type='str',
                required=False
            ),
            file_path=dict(
                type='str',
                required=False
            ),
            # Restrict the state to the actions the Memcached class
            # implements, so anything else is rejected during argument
            # validation.
            state=dict(
                type='str',
                choices=['absent', 'present', 'retrieve'],
                required=True
            ),
            server=dict(
                type='str',
                required=True
            ),
            expires=dict(
                type='int',
                default=300,
                required=False
            ),
            file_mode=dict(
                type='str',
                default='0644',
                required=False
            ),
            dir_mode=dict(
                type='str',
                default='0755',
                required=False
            ),
            # The encryption secret must not be written to logs.
            encrypt_string=dict(
                type='str',
                required=False,
                no_log=True
            )
        ),
        supports_check_mode=False,
        mutually_exclusive=[
            ['content', 'file_path']
        ]
    )
    ms = Memcached(module=module)
    ms.router()

# Run the module only when this file is executed as a script.
if __name__ == '__main__':
    main()
