#!/usr/bin/python

# @file set_ntp_server
#
# Copyright (C) Metaswitch Networks 2016
# If license terms are provided to you in a COPYING file in the root directory
# of the source code repository by which you are accessing this code, then
# the license outlined in that COPYING file applies to your use.
# Otherwise no rights are granted except for those provided to you by
# Metaswitch Networks in a separate written agreement.

import os, pipes, platform, subprocess, sys
from datetime import datetime

# Path of the NTP configuration file that this script reads and rewrites.
NTP_CONFIG = "/etc/ntp.conf"
# Directive keyword: existing lines starting with it are commented out, and
# each newly written address line begins with it.
KEY = "server"
# Path of this script, used in the usage message and in the marker comment
# appended to the configuration file.
SCRIPT = __file__

def set_ntp_server():
    """Rewrite NTP_CONFIG to use the addresses given on the command line.

    Every existing line that starts with KEY is commented out by prefixing
    it with '#'; all other lines are kept unchanged.  The file is then
    rewritten with those lines followed by a blank line, a comment recording
    SCRIPT and the current time, and one "KEY <address>" line for each
    argument in sys.argv[1:] (each address passed through pipes.quote).
    """
    # Read the whole file first: the same path is truncated and rewritten
    # below.
    with open(NTP_CONFIG, 'r') as config:
        rewritten = ['#' + entry if entry.startswith(KEY) else entry
                     for entry in config]

    with open(NTP_CONFIG, 'w') as config:
        config.writelines(rewritten)
        config.write("\n")
        config.write(
            "# NTP Address Configuration added by {0} on {1}\n".format(
                SCRIPT, datetime.now()))
        config.writelines(
            "{0} {1}\n".format(KEY, pipes.quote(address))
            for address in sys.argv[1:])

def restart_ntp():
    """Restart the NTP service using the `service` command.

    The service name is "ntpd" when the distribution name reported by
    platform.linux_distribution() contains "CentOS", and "ntp" otherwise.
    Standard output and standard error of the command are discarded.

    If the command returns a non-zero exit code, an error message is
    printed and the process exits with status 3.
    """
    # NOTE(review): platform.linux_distribution() was removed in Python 3.8;
    # this relies on the interpreter still providing it.
    if "CentOS" in platform.linux_distribution()[0]:
        service_name = "ntpd"
    else:
        service_name = "ntp"

    # Hide output.  The command is passed as an argument list so no shell
    # is spawned just to run a fixed command.
    with open("/dev/null", 'w') as null:
        ret_code = subprocess.call(["service", service_name, "restart"],
                                   stdout=null,
                                   stderr=null)

    if ret_code != 0:
        print("Failed to restart ntp")
        sys.exit(3)

if __name__ == "__main__":
    # Exit codes: 1 = not running as root, 2 = no addresses supplied
    # (restart_ntp exits with 3 itself if the service restart fails).
    running_as_root = os.geteuid() == 0
    if not running_as_root:
        print("This script requires root privileges")
        sys.exit(1)

    # Everything after the script name is treated as an NTP address.
    ntp_addresses = sys.argv[1:]
    if not ntp_addresses:
        usage = "Usage: {0} NTP_ADDRESS_1 NTP_ADDRESS_2 NTP_ADDRESS_3 ...".format(SCRIPT)
        print(usage)
        sys.exit(2)

    # Rewrite the configuration first, then restart the service.
    set_ntp_server()
    restart_ntp()

    print("Updated NTP Address to: " + str(ntp_addresses))

    sys.exit()
