# Task 1: iptables转npf (convert the iptables NAT rules to NPF)

# 263 阅读3分钟 (263 reads, 3-minute read)
""" NatSetting Impl """

import os
import json
import string
import re
import pdb
from common import MemConfig
from common.ConfigMgmt import addFeature
from common.logger import networkLogger
from common.IfManager import get_pppoe
from common.IfManager import get_pppoe_devicename
from common.Utilities import ip_to_num
from common.Utilities import num_to_ip
from common.Utilities import getIPRangeFromCIDR
from interface.fakeARP_Impl import fakeARP_Impl
from interface.NetIfSetting_Impl import NetIfSetting_Impl

# MemConfig key under which the NAT rule table is stored (as JSON).
key = "NatRule"
# Initial, empty rule table written by init(). Rules are later stored under
# the names "rule1", "rule2", ... (see set_rule_to_mem / get_rule_tolist).
default_value = """
{
}
"""
# NOTE: the leading space is load-bearing -- call sites append this path
# directly after a closing quote (e.g. "...'" + NPFCONFIG) so it becomes a
# separate shell argument. Do not strip it.
NPFCONFIG = " /etc/snat.conf"
# Command prefix (run via sudo) that receives iptables-style arguments;
# presumably a wrapper around iptables -- confirm against the script.
IPTABLES = "sudo /usr/vtm/scripts/iptablesEx.sh"


def getip(ethname):
    """Return the "ip" field of interface *ethname*.

    The interface table is loaded from memory and passed to
    NetIfSetting_Impl.update_all_if_status() before the lookup.
    Raises KeyError if *ethname* is unknown.
    """
    interfaces = NetIfSetting_Impl.get_all_if_from_mem()
    NetIfSetting_Impl.update_all_if_status(interfaces)
    entry = interfaces[ethname]
    return entry["ip"]


# NOTE: keep NPF IP ranges small -- a large range makes npf fail with
# "the number of instructions is over the limit".
def assemble_map(rule, index):
    """Build the NPF ``map`` statements for one NAT rule.

    Side effect: immediately runs ``sed -i`` on NPFCONFIG to insert a pair of
    ``#rule<index>_start`` / ``#rule<index>_end`` marker lines (right after
    ``#map_start`` when index is 1, otherwise after ``#rule<index-1>_end``).
    The markers are inserted even when no map line is generated.

    The returned strings are ``sed -i`` commands that append a ``map ...``
    line after ``#rule<index>_start``; the caller must run them.

    Args:
        rule: NAT rule dict with "translation" ({"type", "transform"}),
            "source" and "destination" ({"ip", "port"}) entries.
            NOTE: mutated in place -- "egress" is removed from the SNAT
            transform, and for DNAT the destination ip/port and the
            transform are normalised to "a-b" range form.
        index: 1-based position of the rule; selects the marker lines.

    Returns:
        List of shell command strings; empty when the translation type is
        neither "SNAT" nor "DNAT".
    """
    cmds2 = []
    cmd2 = "sed -i "
    if (index == 1):
        # First rule: create its marker pair right after "#map_start".
        # Both inserts append directly below the anchor, so "_end" is
        # inserted first and ends up below "_start".
        os.system("sed -i \'/#map_start/a #rule1_end\'" + NPFCONFIG)
        os.system("sed -i \'/#map_start/a #rule1_start\'" + NPFCONFIG)
        cmd2 += "\'/#rule1_start/a" + " map "
    else:
        # Later rules: chain the marker pair after the previous rule's "_end".
        os.system("sed -i \'/#rule" + str(index - 1) + "_end/a #rule" + str(index) + "_end\'" + NPFCONFIG)
        os.system("sed -i \'/#rule" + str(index - 1) + "_end/a #rule" + str(index) + "_start\'" + NPFCONFIG)
        cmd2 += "\'/#rule" + str(index) + "_start/a" + " map "
    # get eth num and transform eth1 to $n1
    # (the interface is the last space-separated token of the transform)
    split = rule["translation"]["transform"].split(" ")
    replace = split[len(split) - 1].replace("eth", "")
    cmd2 += "$n" + replace + " dynamic "

    ttype = rule["translation"]["type"]
    if ttype == "SNAT":
        cmd2 += "0.0.0.0 -> "
        # A transform containing "egress" is handled as MASQUERADE.
        if rule["translation"]["transform"].find("egress") >= 0:
            rule["translation"]["transform"] = rule["translation"]["transform"].replace("egress", "")
            ttype = "MASQUERADE"
        if ttype == "MASQUERADE":
            # Translation target is the interface variable itself.
            cmd2 += "$n" + replace + " "
        elif ttype == "SNAT":
            # Explicit translation address: second token of the transform.
            cmd2_dip_arr = rule["translation"]["transform"].split(" ")
            cmd2_dip = cmd2_dip_arr[1]
            cmd2 += cmd2_dip + " "
        # transform "-m iprange --src-range 192.168.3.100-192.168.3.110" to "pass from 192.168.3.100-192.168.3.110"
        cmd2 += "pass from "
        sip = rule["source"]["ip"]
        if sip != "":
            cmd2 += sip + " "
        else:
            cmd2 += "any "
        # transform "-m multiport --sports 1000:2000" to "port 1000-2000"
        sport = rule["source"]["port"]
        if sport != "":
            cmd2 += "port " + sport + " "
        cmd2 += "to "
        # transform "-m iprange --dst-range 192.168.3.10-192.168.3.15" to "to 192.168.3.10-192.168.3.15"
        dip = rule["destination"]["ip"]
        if dip != "":
            cmd2 += dip + " "
        else:
            cmd2 += "any "
        dport = rule["destination"]["port"]
        if dport != "":
            cmd2 += "port " + dport

        # Close the sed expression and name the target file.
        cmd2 += "\'" + NPFCONFIG
        cmds2.append(cmd2)

    elif ttype == "DNAT":
        # Source-side filter shared by every generated DNAT map line.
        temp = " <- 0.0.0.0 pass from "
        sip = rule["source"]["ip"]
        if sip != "":
            temp += sip + " "
        else:
            temp += "any "
        sport = rule["source"]["port"]
        if sport != "":
            temp += "port " + sport + " "
        transform_ary = rule["translation"]["transform"].split(' ')

        # formalize the transform_ary[1] and turned it into ip1-ip2[:port1-port2]
        trans_ip_port_info = transform_ary[1]
        trans_ip_port_info_ary = trans_ip_port_info.split(":")
        trans_ip_section = trans_ip_port_info_ary[0]
        dip = rule["destination"]["ip"]
        dport = rule["destination"]["port"]
        trans_port_section = ""
        if not ("-" in trans_ip_section):
            # Single address: turn both the translation address and the
            # matched destination into a one-element "a-a" range.
            transform_ary[1] = trans_ip_section + "-" + trans_ip_section
            rule["destination"]["ip"] += "-" + rule["destination"]["ip"]
        if len(trans_ip_port_info_ary) == 2:
            trans_port_section = trans_ip_port_info_ary[1]
            if not ("-" in trans_port_section) and trans_port_section != "":
                trans_port_section += "-" + trans_port_section
                transform_ary[1] += ":" + trans_port_section
        if dport != "":
            # NOTE(review): appended unconditionally, so an existing "a-b"
            # range becomes "a-b-a-b"; only the first two fields are read below.
            rule["destination"]["port"] += "-" + rule["destination"]["port"]

        # string.join is Python-2 only; equivalent to " ".join(transform_ary).
        rule["translation"]["transform"] = string.join(transform_ary, " ")

        # handle one-one mapping for ip and port
        trans_ip_port_info = transform_ary[1]
        trans_ip_port_info_ary = trans_ip_port_info.split(":")
        trans_ip_section = trans_ip_port_info_ary[0]
        dip = rule["destination"]["ip"]
        dport = rule["destination"]["port"]
        trans_ip_ary = trans_ip_section.split('-')

        trans_ip_start = ip_to_num(trans_ip_ary[0])
        trans_ip_end = ip_to_num(trans_ip_ary[1])

        # create list trans_ip_list and dst_ip_list
        trans_ip_list = []
        dst_ip_list = []
        dst_ip_ary = dip.split('-')
        if dst_ip_ary[0] != "ingress":
            dst_ip_start = ip_to_num(dst_ip_ary[0])
            dst_ip_end = ip_to_num(dst_ip_ary[1])

        if dst_ip_ary[0] != "ingress":
            # Pair destination and translation addresses one-to-one; the
            # count comes from the translation range (dst_ip_end is unused).
            for i in range(trans_ip_end - trans_ip_start + 1):
                dst_ip_list.append(num_to_ip(dst_ip_start + i))
                trans_ip_list.append(num_to_ip(trans_ip_start + i))
        else:
            # "ingress" destination: match "to any" below and translate to
            # the first address of the range only.
            dst_ip_list.append("ingress")
            trans_ip_list.append(num_to_ip(trans_ip_start))

        # create list trans_port_list and dst_port_list
        dst_port_list = []
        trans_port_list = []
        if dport != "":
            dst_port_ary = dport.split('-')
            dst_port_start = int(dst_port_ary[0])
            dst_port_end = int(dst_port_ary[1])
            for i in range(dst_port_end - dst_port_start + 1):
                dst_port_list.append(str(dst_port_start + i))
        if trans_port_section != "":
            trans_port_ary = trans_port_section.split('-')
            trans_port_start = int(trans_port_ary[0])
            trans_port_end = int(trans_port_ary[1])
            for i in range(trans_port_end - trans_port_start + 1):
                trans_port_list.append(str(trans_port_start + i))

        # create one npf map command per address (and port) combination
        for i in range(len(trans_ip_list)):
            transform_instance_ary = transform_ary[:]
            temp_instance = temp
            if dst_ip_ary[0] != "ingress":
                temp_instance += " to " + dst_ip_list[i] + " "
            else:
                temp_instance += " to any "
            transform_instance_ary[1] = trans_ip_list[i]
            if len(trans_port_list) != 0:  # then dst_port_list will also not empty
                # One map line per port pair; assumes dst_port_list is at
                # least as long as trans_port_list (IndexError otherwise).
                for j in range(len(trans_port_list)):
                    cmd_instance = cmd2
                    temp_instance_with_port = temp_instance
                    # NOTE(review): tranform_instance_with_port_ary is built but never read.
                    tranform_instance_with_port_ary = transform_ary[:]
                    tranform_instance_with_port_ary[1] = transform_instance_ary[1]
                    temp_instance_with_port += " port " + dst_port_list[j]

                    cmd_instance += transform_instance_ary[1] + " port " + trans_port_list[j]
                    map_instance = cmd_instance + temp_instance_with_port
                    map_instance += "\'" + NPFCONFIG
                    cmds2.append(map_instance)
            elif len(dst_port_list) != 0:  # else when dst_port_list is not empty, it's just matching condition
                dport_arr = dport.split("-")
                temp_instance += " port " + dport_arr[0] + "-" + dport_arr[1]
                cmd2_instance = cmd2
                cmd2_instance += transform_instance_ary[1]
                map_instance = cmd2_instance + temp_instance
                map_instance += "\'" + NPFCONFIG
                cmds2.append(map_instance)
            else:  # else no port condition
                cmd2_instance = cmd2
                cmd2_instance += transform_instance_ary[1]
                map_instance = cmd2_instance + temp_instance
                map_instance += "\'" + NPFCONFIG
                cmds2.append(map_instance)
    return cmds2


def _iptables_source_match(rule):
    """Return the protocol / source-ip / source-port iptables match options of *rule*."""
    opts = ""
    protocol = rule["protocol"]
    if protocol != "":
        opts += "-p " + protocol + " "
    sip = rule["source"]["ip"]
    if "-" in sip:
        opts += "-m iprange --src-range " + sip + " "
    elif sip != "":
        opts += "-s " + sip + " "
    # multiport expects "first:last" port ranges.
    sport = rule["source"]["port"].replace("-", ":")
    if sport != "":
        opts += "-m multiport --sports " + sport + " "
    return opts


def _expand_port_range(port_range):
    """Expand "first-last" into ["first", ..., "last"]; extra fields are ignored, "" gives []."""
    if port_range == "":
        return []
    bounds = port_range.split('-')
    first = int(bounds[0])
    last = int(bounds[1])
    return [str(port) for port in range(first, last + 1)]


def _assemble_snat(rule, cmd):
    """Return the single iptables command for an SNAT rule (MASQUERADE for "egress")."""
    translation = rule["translation"]
    ttype = "SNAT"
    # A transform containing "egress" is handled as MASQUERADE.
    if "egress" in translation["transform"]:
        translation["transform"] = translation["transform"].replace("egress", "")
        ttype = "MASQUERADE"

    cmd += "-A _SNAT " + _iptables_source_match(rule)
    dip = rule["destination"]["ip"]
    if "-" in dip:
        cmd += "-m iprange --dst-range " + dip + " "
    elif dip != "":
        cmd += "-d " + dip + " "
    dport = rule["destination"]["port"].replace("-", ":")
    if dport != "":
        cmd += "-m multiport --dports " + dport + " "

    # --to-source takes ipaddr[-ipaddr], not CIDR: expand a CIDR target.
    if "/" in translation["transform"]:
        parts = translation["transform"].split(" ")
        parts[1] = getIPRangeFromCIDR(parts[1])
        translation["transform"] = " ".join(parts)

    # The destination port is matched by "-m multiport --dports" above. No
    # NPF-style " port <dport>" suffix may follow the target: iptables
    # rejects it as an unknown argument.
    return cmd + " -j " + ttype + " " + map_pppoe(translation["transform"])


def _assemble_dnat(rule, cmd):
    """Return the iptables DNAT commands, one per destination/translation address (and port) pair."""
    ttype = "DNAT"
    cmd += "-A _DNAT " + _iptables_source_match(rule)

    # Normalise the translation target to "ip1-ip2[:port1-port2]" and the
    # matched destination ip/port to "a-b" form (rule is updated in place).
    transform_ary = rule["translation"]["transform"].split(' ')
    trans_ip_port_info_ary = transform_ary[1].split(":")
    trans_ip_section = trans_ip_port_info_ary[0]
    trans_port_section = ""
    if "-" not in trans_ip_section:
        transform_ary[1] = trans_ip_section + "-" + trans_ip_section
        rule["destination"]["ip"] += "-" + rule["destination"]["ip"]
    if len(trans_ip_port_info_ary) == 2:
        trans_port_section = trans_ip_port_info_ary[1]
        if "-" not in trans_port_section and trans_port_section != "":
            trans_port_section += "-" + trans_port_section
            transform_ary[1] += ":" + trans_port_section
    if rule["destination"]["port"] != "":
        rule["destination"]["port"] += "-" + rule["destination"]["port"]
    rule["translation"]["transform"] = " ".join(transform_ary)

    dip = rule["destination"]["ip"]
    dport = rule["destination"]["port"]
    trans_ip_ary = transform_ary[1].split(":")[0].split('-')
    trans_ip_start = ip_to_num(trans_ip_ary[0])
    trans_ip_end = ip_to_num(trans_ip_ary[1])

    # Pair destination and translation addresses one-to-one. An "ingress"
    # destination adds no -d match and uses only the first translation address.
    dst_ip_ary = dip.split('-')
    ingress = dst_ip_ary[0] == "ingress"
    if ingress:
        dst_ip_list = ["ingress"]
        trans_ip_list = [num_to_ip(trans_ip_start)]
    else:
        dst_ip_start = ip_to_num(dst_ip_ary[0])
        # Also parse the upper bound so a malformed destination range is rejected.
        ip_to_num(dst_ip_ary[1])
        count = trans_ip_end - trans_ip_start + 1
        dst_ip_list = [num_to_ip(dst_ip_start + k) for k in range(count)]
        trans_ip_list = [num_to_ip(trans_ip_start + k) for k in range(count)]

    dst_port_list = _expand_port_range(dport)
    trans_port_list = _expand_port_range(trans_port_section)

    # Destination port used purely as a match condition: "a-a" -> "a",
    # "a-b" -> "a:b" (multiport range syntax).
    port_match = ""
    if dst_port_list:
        bounds = dport.split('-')
        if len(bounds) == 1 or bounds[0] == bounds[1]:
            port_match = bounds[0]
        else:
            port_match = bounds[0] + ":" + bounds[1]

    cmds = []
    for i, trans_ip in enumerate(trans_ip_list):
        cmd_instance = cmd
        if not ingress:
            cmd_instance += " -d " + dst_ip_list[i]
        if trans_port_list:
            # One command per port pair; dst_port_list is assumed to be at
            # least as long as trans_port_list.
            for j, trans_port in enumerate(trans_port_list):
                target = transform_ary[:]
                target[1] = trans_ip + ":" + trans_port
                cmds.append(cmd_instance + " -m multiport --dports " + dst_port_list[j]
                            + " -j " + ttype + " " + " ".join(target))
        else:
            target = transform_ary[:]
            target[1] = trans_ip
            if dst_port_list:
                cmd_instance += " -m multiport --dports " + port_match
            cmds.append(cmd_instance + " -j " + ttype + " " + " ".join(target))
    return cmds


def assemble(rule):
    """Build the iptables nat-table commands for one NAT rule.

    SNAT rules yield one command appended to the _SNAT chain (MASQUERADE when
    the transform contains "egress"). DNAT rules yield one _DNAT command per
    destination/translation address pair, or per port pair when the
    translation carries ports. Rules of any other type -- including a
    stored "MASQUERADE" type, which filter_rule groups with SNAT -- produce
    no commands. NOTE(review): confirm that is intended.

    NOTE: *rule* is mutated in place ("egress" stripped, CIDR targets
    expanded, DNAT ip/port fields normalised to "a-b" ranges).

    Args:
        rule: NAT rule dict with "protocol", "source"/"destination"
            ({"ip", "port"}) and "translation" ({"type", "transform"}).

    Returns:
        List of shell command strings, with PPPoE interface names mapped
        via map_pppoe().
    """
    cmd = IPTABLES + " -t nat "
    ttype = rule["translation"]["type"]
    if ttype == "SNAT":
        cmds = [_assemble_snat(rule, cmd)]
    elif ttype == "DNAT":
        cmds = _assemble_dnat(rule, cmd)
    else:
        cmds = []
    return [map_pppoe(c) for c in cmds]


def map_pppoe(trans):
    """Replace PPPoE-backed interface names in a space-separated string.

    Each token that contains "eth" and appears in get_pppoe() is replaced
    by get_pppoe_devicename(token). All other tokens, and the exact
    single-space separation (including empty tokens), are preserved.

    Args:
        trans: command or transform string.

    Returns:
        The rewritten string, or *trans* unchanged when it has no "eth".
    """
    if "eth" not in trans:
        return trans
    # get_pppoe() is loop-invariant: query it once, not once per token.
    pppoe_ifs = get_pppoe()
    return " ".join(
        get_pppoe_devicename(tok) if "eth" in tok and tok in pppoe_ifs else tok
        for tok in trans.split(" ")
    )


def get_all_rule_from_mem():
    """Return the whole NAT rule table, decoded from the JSON stored under `key`."""
    raw_json = MemConfig.loadOneConfig(key)
    table = json.loads(raw_json)
    return table


def get_rule_from_mem(rule):
    """Return the rule stored under name *rule*; raises KeyError if it is absent."""
    table = get_all_rule_from_mem()
    found = table[rule]
    return found


def set_all_rule_to_mem(all):
    """Persist the complete rule table as indented JSON under `key`.

    Args:
        all: dict mapping rule names to rule dicts. Any rule dict without a
            "description" gets an empty one added in place.
    """
    # dict.has_key() is deprecated (removed in Python 3); setdefault only
    # inserts when the key is missing, which is exactly the old behavior.
    for rule_name in all:
        all[rule_name].setdefault("description", "")
    MemConfig.setConfig(key, json.dumps(all, indent=4))


def set_rule_to_mem(rule, value):
    """Insert/replace one rule, or delete it and renumber the remaining rules.

    Args:
        rule: rule name such as "rule3".
        value: the new rule dict (an empty "description" is added in place
            if missing). Pass "" instead to delete *rule*; the remaining
            rules are then renamed "rule1".."ruleN" in their original order.
    """
    rules = get_all_rule_from_mem()
    if value != "":
        # dict.has_key() is deprecated; setdefault is the portable equivalent.
        value.setdefault("description", "")
        rules[rule] = value
    else:
        # Walk rule1..ruleN in order, dropping the deleted one.
        kept = [rules["rule" + str(n)] for n in range(1, len(rules) + 1)
                if rule != "rule" + str(n)]
        rules = dict(("rule" + str(n), r) for n, r in enumerate(kept, 1))
    set_all_rule_to_mem(rules)


def filter_rule(type):
    """Return the stored rules (in rule1..ruleN order) that belong to chain *type*.

    "_SNAT" selects SNAT and MASQUERADE rules, "_DNAT" selects DNAT rules;
    any other value yields an empty list.
    """
    if type == "_SNAT":
        wanted = ("SNAT", "MASQUERADE")
    elif type == "_DNAT":
        wanted = ("DNAT",)
    else:
        wanted = ()
    table = get_all_rule_from_mem()
    selected = []
    for n in range(1, len(table) + 1):
        entry = table["rule" + str(n)]
        if entry["translation"]["type"] in wanted:
            selected.append(entry)
    return selected


def apply_rules(rules):
    """Append the iptables commands for every rule in *rules*.

    Before the rules, 127.0.0.1 is exempted (RETURN) in both the _SNAT and
    _DNAT chains. A failing command is only logged. Always returns True;
    an empty *rules* list is a no-op.
    """
    if not rules:
        return True
    # do not translate local address
    os.system("sudo iptables -t nat -A _SNAT -s 127.0.0.1/32 -j RETURN >/dev/null 2>&1")
    os.system("sudo iptables -t nat -A _DNAT -d 127.0.0.1/32 -j RETURN >/dev/null 2>&1")
    for rule in rules:
        for command in assemble(rule):
            networkLogger.info("cmd: " + command)
            status = os.system(command + " >/dev/null 2>&1")
            if status != 0:
                networkLogger.debug("cmd failed")
    return True


def apply_npf_rules(rules):
    """Write the NPF map lines for *rules* into NPFCONFIG, then reload it.

    Rules are numbered from 1 in list order; the number is passed to
    assemble_map() to select its marker lines. Failures are logged, never
    raised. Always returns True.
    """
    for position, rule in enumerate(rules, 1):
        for sed_cmd in assemble_map(rule, position):
            networkLogger.debug("cmd2: ")
            networkLogger.debug(sed_cmd)
            if os.system(sed_cmd + " >/dev/null 2>&1") != 0:
                networkLogger.error("cmd2 failed")
    reload_cmd = "sudo fp-npfctl reload " + NPFCONFIG + " >/dev/null 2>&1"
    if os.system(reload_cmd) != 0:
        networkLogger.error("npf reload failed!")
    return True


def flush(table, chain):
    """Flush *chain* of iptables *table* through the IPTABLES command; output is discarded."""
    command = " ".join([IPTABLES, "-t", table, "-F", chain, ">/dev/null 2>&1"])
    networkLogger.debug("cmd: " + command)
    os.system(command)


def npf_flush():
    """Delete every line strictly between the #map_start and #map_end markers of NPFCONFIG."""
    sed_cmd = " sed -i  '/#map_start/,/#map_end/{/#map_start/!{/#map_end/!d}}' "
    os.system(sed_cmd + NPFCONFIG)


def _interface_after_flag(tokens, flag):
    """Return the token following *flag* in *tokens*, or "*" when absent or last."""
    if flag in tokens:
        idx = tokens.index(flag) + 1
        if idx < len(tokens):
            return tokens[idx]
    return "*"


def submitFakeARPRequest(rules):
    """Register the addresses used by NAT rules with fakeARP_Impl.

    - SNAT rules without "egress": the translation address range (second
      transform token, with any ":port" suffix stripped).
    - DNAT rules: the destination ip, unless it is "" or "ingress".

    The interface passed along is the token after "-o" (SNAT) or "-i" (DNAT)
    in the transform, or "*" when there is none. Flags are matched as whole
    tokens: a substring match would also hit options such as
    "--out-interface"/"--in-interface" and then crash in list.index().

    Args:
        rules: dict mapping rule names to rule dicts.
    """
    for name, rule in rules.items():
        translation = rule["translation"]
        tokens = re.split("[ ]+", translation["transform"])
        if translation["type"] == "SNAT":
            if "egress" in translation["transform"]:
                continue
            addr_range = tokens[1].split(":")[0]
            iface = _interface_after_flag(tokens, "-o")
        elif translation["type"] == "DNAT":
            if rule["destination"]["ip"] in ("", "ingress"):
                continue
            addr_range = rule["destination"]["ip"]
            iface = _interface_after_flag(tokens, "-i")
        else:
            continue
        fakeARP_Impl.addNATRequirement(iface, name, addr_range)


def apply_vpn_nat(rules):
    """Run the iptables commands for a VPN's NAT rules.

    Args:
        rules: dict mapping rule names to rule dicts.

    Returns:
        False as soon as one command fails (the remaining commands are not
        run), True when every command succeeded. The success path used to
        fall off the end and return None.
    """
    for name in rules:
        for cmd in assemble(rules[name]):
            if os.system(cmd + " >/dev/null 2>&1") != 0:
                networkLogger.debug("cmd failed")
                return False
    return True


def get_rule_tolist():
    """Return every stored rule as a list ordered rule1, rule2, ..., ruleN."""
    table = get_all_rule_from_mem()
    return [table["rule" + str(n)] for n in range(1, len(table) + 1)]


def apply_all(vpn_remove_list=None):
    """Rebuild all NAT state: NPF map section, iptables chains and fakeARP.

    Args:
        vpn_remove_list: VPN names ("l2tp", "ssl") whose NAT rules must not
            be re-applied. Defaults to none. None replaces the former
            shared mutable ``[]`` default; callers passing a list are
            unaffected.

    Returns:
        True, or False if apply_rules / apply_npf_rules report failure.
        Failures of the VPN NAT rules are not propagated.
    """
    if vpn_remove_list is None:
        vpn_remove_list = ()

    npf_flush()
    flush("nat", "_SNAT")
    flush("nat", "_DNAT")

    # l2tp over ipsec nat rule
    if "l2tp" not in vpn_remove_list:
        l2tp_setting = json.loads(MemConfig.loadOneConfig("L2tpSetting"))
        apply_vpn_nat(l2tp_setting["vpnNatRule"])

    # ssl vpn nat rule
    if "ssl" not in vpn_remove_list:
        openvpn_setting = json.loads(MemConfig.loadOneConfig("OpenVpnSetting"))
        apply_vpn_nat(openvpn_setting["vpnNatRule"])

    if not apply_rules(filter_rule("_SNAT")):
        return False
    if not apply_rules(filter_rule("_DNAT")):
        return False
    if not apply_npf_rules(get_rule_tolist()):
        return False

    fakeARP_Impl.flushFakeARPNATRules()
    submitFakeARPRequest(get_all_rule_from_mem())
    fakeARP_Impl.restartFakeARPD()

    return True


def init():
    """Reset the stored rule table to `default_value`, then call addFeature(key)."""
    empty_table = default_value
    MemConfig.setConfig(key, empty_table)
    addFeature(key)


def bypass_snat_for_ipsec(remote_networks):
    """Rebuild the nat-table _IPSEC chain with an ACCEPT rule per remote network.

    Named for bypassing SNAT on IPsec traffic; presumably _IPSEC is consulted
    before the _SNAT rules -- confirm where the chain is hooked in.
    """
    flush("nat", "_IPSEC")
    prefix = IPTABLES + " -t nat -A _IPSEC -d "
    suffix = " -j ACCEPT >/dev/null 2>&1"
    for net in remote_networks:
        command = prefix + net + suffix
        networkLogger.info("cmd: " + command)
        os.system(command)