#!/usr/bin/python
#
# Copyright (c) 2017 Mellanox Technologies. All rights reserved.
#
# This Software is licensed under one of the following licenses:
#
# 1) under the terms of the "Common Public License 1.0" a copy of which is
#    available from the Open Source Initiative, see
#    http://www.opensource.org/licenses/cpl.php.
#
# 2) under the terms of the "The BSD License" a copy of which is
#    available from the Open Source Initiative, see
#    http://www.opensource.org/licenses/bsd-license.php.
#
# 3) under the terms of the "GNU General Public License (GPL) Version 2" a
#    copy of which is available from the Open Source Initiative, see
#    http://www.opensource.org/licenses/gpl-license.php.
#
# Licensee has the right to choose one of the above licenses.
#
# Redistributions of source code must retain the above copyright
# notice and one of the license notices.
#
# Redistributions in binary form must reproduce both the above copyright
# notice, one of the license notices in the documentation
# and/or other materials provided with the distribution.
#

import sys
import os
if os.path.exists('/usr/share/pyshared'):
	sys.path.append('/usr/share/pyshared')
from optparse import OptionParser
from dcbnetlink import DcbController
from collections import defaultdict
from subprocess import Popen, PIPE

# DCBX mode bits.  They are OR-ed together when passed to
# DcbController.set_dcbx() and tested with & against get_dcbx().
# NOTE(review): names presumably mirror the kernel's dcbnl.h - confirm.
DCB_CAP_DCBX_HOST = 1 << 0
DCB_CAP_DCBX_LLD_MANAGED = 1 << 1
DCB_CAP_DCBX_VER_CEE = 1 << 2
DCB_CAP_DCBX_VER_IEEE = 1 << 3
DCB_CAP_DCBX_STATIC = 1 << 4

# Transmission selection algorithm (TSA) codes; one is kept per traffic class.
IEEE_8021QAZ_TSA_STRICT = 0
IEEE_8021QAZ_TSA_CB_SHAPER = 1
IEEE_8021QAZ_TSA_ETS = 2
IEEE_8021QAZ_TSA_VENDOR = 255

class Maxrate:
	"""Common base of the per-TC rate-limit backends.

	Subclasses override get() and set(); prepare() is shared by them.
	"""

	def get(self):
		"""Return the current rate limits.  Placeholder: returns None."""
		pass

	def set(self, ratelimit):
		"""Apply *ratelimit*.  Placeholder: does nothing."""
		pass

	def prepare(self, ratelimit):
		"""Return *ratelimit* padded up to 8 entries with the current values.

		The entries the caller did not supply (index len(ratelimit)..7) are
		taken from self.get(), so that TCs which were not mentioned keep
		their present limit.  The result is a new list; *ratelimit* itself
		is left untouched.
		"""
		current = self.get()
		# Concatenate into a fresh list instead of using ``+=``, which
		# would extend the caller's list in place.
		return list(ratelimit) + list(current[len(ratelimit):8])

class MaxrateNL(Maxrate):
	"""Rate-limit backend that talks to the DCB controller object."""

	def __init__(self, ctrl):
		# ctrl must provide get_ieee_maxrate() and set_ieee_maxrate()
		self.ctrl = ctrl

	def get(self):
		"""Return whatever the controller reports as the current limits."""
		# Use the controller given to the constructor, not a module global.
		return self.ctrl.get_ieee_maxrate()

	def set(self, ratelimit):
		"""Pad *ratelimit* with the current values and pass it to the controller."""
		ratelimit = self.prepare(ratelimit)
		self.ctrl.set_ieee_maxrate(ratelimit)

class MaxrateSysfs(Maxrate):
	"""Rate-limit backend that reads and writes a sysfs attribute file."""

	def __init__(self, path):
		# path of the attribute file holding the whitespace separated limits
		self.path = path

	def get(self):
		"""Return the values stored in the file as a list of floats.

		Raises ValueError if an entry is not a number and IOError/OSError
		if the file cannot be read.
		"""
		# ``with`` closes the file even when float() rejects an entry.
		with open(self.path, "r") as f:
			return [float(item) for item in f.read().split()]

	def set(self, ratelimit):
		"""Pad *ratelimit* with the current values and write it to the file.

		The entries are written space separated on one line.
		"""
		ratelimit = self.prepare(ratelimit)
		# NOTE(review): padded entries come from get() as floats, so they are
		# written like "1000.0" - confirm the sysfs attribute accepts that.
		with open(self.path, "w") as f:
			f.write(" ".join(str(r) for r in ratelimit))

class Trust:
	"""Base of the trust-mode backends, plus name <-> code conversion."""

	# (name, numeric code) pairs, always checked in this order
	_MODES = (("pcp", 1), ("dscp", 2))

	def get(self):
		"""Placeholder for the backends; returns None."""
		return None

	def set(self, trust):
		"""Placeholder for the backends; does nothing and returns None."""
		return None

	def to_int(self, trust):
		"""Return the numeric code of *trust*.

		A known name is translated, a value equal to 1 or 2 is returned
		unchanged, and anything else yields 0xFF.
		"""
		for name, code in self._MODES:
			if trust == name:
				return code
		if trust == 1 or trust == 2:
			return trust
		return 0xFF

	def to_string(self, trust):
		"""Return the name of *trust*.

		A known code is translated, a value equal to "pcp" or "dscp" is
		returned unchanged, and anything else yields "none".
		"""
		for name, code in self._MODES:
			if trust == code:
				return name
		if trust == "pcp" or trust == "dscp":
			return trust
		return "none"

class TrustNL(Trust):
	"""Trust-mode backend that talks to the DCB controller object."""

	def __init__(self, ctrl):
		# ctrl must provide get_ieee_trust() and set_ieee_trust()
		self.ctrl = ctrl

	def get(self):
		"""Return the trust mode as reported by the controller."""
		# Use the controller given to the constructor, not a module global.
		return self.ctrl.get_ieee_trust()

	def set(self, trust):
		"""Convert *trust* with to_int() and pass the code to the controller."""
		trust = self.to_int(trust)
		self.ctrl.set_ieee_trust(trust)

class TrustSysfs(Trust):
	"""Trust-mode backend that reads and writes a sysfs attribute file."""

	def __init__(self, path):
		# path of the attribute file holding the trust mode
		self.path = path

	def get(self):
		"""Return the file content with all newline characters removed."""
		# ``with`` guarantees the file is closed even if read() fails.
		with open(self.path, "r") as f:
			return f.read().replace('\n', '')

	def set(self, trust):
		"""Write *trust* to the file exactly as given (it must be a string)."""
		with open(self.path, "w") as f:
			f.write(trust)


def pretty_print(prio_tc, tsa, tcbw, ratelimit, pfc_en, trust):
	"""Print the trust mode, the PFC bitmap and the settings of each TC.

	prio_tc   -- TC number of each priority (index = priority)
	tsa       -- TSA code of each TC (IEEE_8021QAZ_TSA_*)
	tcbw      -- bandwidth percent of each TC, shown for ETS TCs only
	ratelimit -- rate limit of each TC; 0 or a missing entry is shown as
	             "unlimited", other values are divided by 1000 * 1000
	             and shown as Gbps
	pfc_en    -- bitmap, bit N set means PFC is enabled on priority N
	trust     -- trust mode name, printed as is

	Only TCs with at least one priority mapped to them are listed, unless
	the module-level ``printall`` flag is set, in which case TCs 0..7 are
	all listed.  TCs are printed in ascending order.
	"""
	print ("Priority trust mode: " + trust)

	print("PFC configuration:")
	print("\tpriority    0   1   2   3   4   5   6   7")
	print("\tenabled     " +
	      "".join("%1d   " % ((pfc_en >> up) & 0x01) for up in range(8)))
	print("")

	# TC -> priorities mapped to it.  With ``printall`` every TC gets an
	# entry up front so that TCs without priorities are listed too.
	if printall:
		tc2up = dict((tc, []) for tc in range(8))
	else:
		tc2up = {}

	for up, tc in enumerate(prio_tc):
		tc2up.setdefault(int(tc), []).append(up)

	# sorted(): do not depend on the order in which the TCs were first seen
	for tc in sorted(tc2up):
		rate = "unlimited"
		try:
			if ratelimit[tc] > 0:
				# float division so that Python 2 does not truncate
				rate = "%.1f Gbps" % (ratelimit[tc] / 1000.0 / 1000.0)
		except (IndexError, KeyError, TypeError):
			# No rate-limit information for this TC (e.g. the feature is
			# not supported); the TC line is still printed.
			pass

		# Built outside the try block so the TC number is never lost.
		msg = "tc: %d ratelimit: %s, tsa: " % (tc, rate)

		try:
			if tsa[tc] == IEEE_8021QAZ_TSA_ETS:
				msg += "ets, bw: %s%%" % (tcbw[tc],)
			elif tsa[tc] == IEEE_8021QAZ_TSA_STRICT:
				msg += "strict"
			elif tsa[tc] == IEEE_8021QAZ_TSA_VENDOR:
				msg += "vendor"
			else:
				msg += "unknown"
		except (IndexError, KeyError, TypeError):
			# No TSA/bandwidth entry for this TC: leave the field empty.
			pass

		print(msg)

		for up in tc2up[tc]:
			print("\t priority:  %s" % up)

def parse_int(str, min, max, description):
	"""Convert *str* to an integer within min..max (both inclusive).

	On a value that is not an integer or lies outside the range, a
	message naming *description* is printed, followed by the usage text
	of the module-level ``parser``, and the script exits with status 1.
	"""
	try:
		value = int(str)
		if not (min <= value <= max):
			raise ValueError("%d is not in the range %d..%d" % (value, min, max))
	except ValueError as err:
		print("Bad value for %s: %s" % (description, err))
		parser.print_usage()
		sys.exit(1)

	return value

# Command line definition.  The parsed values end up in ``options``;
# leftover positional arguments end up in ``args``.
parser = OptionParser(usage="%prog -i <interface> [options]", version="%prog 1.0")

parser.add_option("-f", "--pfc", dest="pfc",
		  help="Set priority flow control for each priority. LIST is " +
			"comma separated value for each priority starting from 0 to 7. " +
			"Example: 0,0,0,0,1,1,1,1 enable PFC on priorities 4-7",
		  metavar="LIST")
parser.add_option("-p", "--prio_tc", dest="prio_tc",
		  help="maps UPs to TCs. LIST is 8 comma separated TC numbers. " +
			"Example: 0,0,0,0,1,1,1,1 maps UPs 0-3 to TC0, and UPs 4-7 to " +
			"TC1",
		  metavar="LIST")
parser.add_option("-s", "--tsa", dest="tsa",
		  help="Transmission algorithm for " +
			"each TC. LIST is comma separated algorithm names for each TC. " +
			"Possible algorithms: strict, ets. Example: ets,strict,ets sets " +
			"TC0,TC2 to ETS and TC1 to strict. The rest are unchanged.",
		  metavar="LIST")
parser.add_option("-t", "--tcbw", dest="tc_bw",
		  help="Set minimal guaranteed %BW for ETS TCs. LIST is comma " +
			"separated percents for each TC. Values set to TCs that are " +
			"not configured to ETS algorithm are ignored, but must be " +
			"present. Example: if TC0,TC2 are set to ETS, then 10,0,90 " +
			"will set TC0 to 10% and TC2 to 90%. Percents must sum to " +
			"100.",
		  metavar="LIST")
parser.add_option("-r", "--ratelimit", dest="ratelimit",
		  help="Rate limit for TCs (in Gbps). LIST is a comma separated " +
			"Gbps limit for each TC. Example: 1,8,8 will limit TC0 to " +
			"1Gbps, and TC1,TC2 to 8 Gbps each.",
		  metavar="LIST")
parser.add_option("-d", "--dcbx", dest="dcbx",
		  help="get the dcbx mode(get) or set dcbx mode to firmware controlled(fw) or " +
			"os controlled(os)")
parser.add_option("--trust", dest="trust",
		  help="set priority trust mode to pcp or dscp")
parser.add_option("-i", "--interface", dest="intf",
		  help="Interface name")

parser.add_option("-a", action="store_true", dest="printall", default=False,
		  help="Show all interface's TCs")

(options, args) = parser.parse_args()

# Positional arguments are not accepted.
if args:
	print("Bad arguments")
	parser.print_usage()
	sys.exit(1)

# -i/--interface is mandatory.
if options.intf is None:
	print("Interface name is required")
	parser.print_usage()
	sys.exit(1)

# sysfs attributes of the interface; each one is used further down only
# if the file exists.
ratelimit_path = "/sys/class/net/" + options.intf + "/qos/maxrate"
trust_path = "/sys/class/net/" + options.intf + "/qos/trust"

# Values shown when the corresponding query below fails: PFC disabled,
# all eight TCs strict with 0% bandwidth, every priority mapped to TC0.
pfc_en = 0
tsa = [IEEE_8021QAZ_TSA_STRICT] * 8
tc_bw = [0] * 8
prio_tc = [0] * 8
printall = False

ctrl = DcbController(options.intf)

# -d/--dcbx: optionally switch the DCBX mode, report it and stop.
if options.dcbx is not None:
	if options.dcbx == "os":
		ctrl.set_dcbx(ctrl.get_dcbx() | DCB_CAP_DCBX_HOST)
	elif options.dcbx == "fw":
		ctrl.set_dcbx(0)
	# any other value (e.g. "get") only queries the mode

	if ctrl.get_dcbx() & DCB_CAP_DCBX_HOST:
		print("DCBX mode: OS controlled")
	else:
		print("DCBX mode: Firmware controlled")

	# The request was served, so report success to the caller; a failing
	# get_dcbx()/set_dcbx() raises before this line is reached.
	sys.exit(0)

# Priority trust mode: use the sysfs attribute when it exists, otherwise
# the controller; read the mode and change it when --trust was given.
try:
	if os.path.exists(trust_path):
		trust_obj = TrustSysfs(trust_path)
	else:
		trust_obj = TrustNL(ctrl)

	trust = trust_obj.to_string(trust_obj.get())

	if options.trust:
		trust_obj.set(options.trust)
		trust = options.trust
except Exception:
	# ``Exception`` rather than a bare except, so that Ctrl-C and
	# sys.exit() are not mistaken for a missing feature.
	print ("Priority trust mode is not supported on your system")
	if options.trust:
		# the user explicitly asked for a change that could not be made
		sys.exit(1)
	else:
		trust = "none"

# Rate limits: use the sysfs attribute when it exists, otherwise the
# controller; apply -r when given, then read back what is in effect.
ratelimit = []
if os.path.exists(ratelimit_path):
	maxrate = MaxrateSysfs(ratelimit_path)
else:
	maxrate = MaxrateNL(ctrl)

if options.ratelimit:
	fields = options.ratelimit.split(",")
	# same limit as the other per-TC lists
	if len(fields) > 8:
		print("Too many items for ratelimit")
		sys.exit(1)

	# -r is given in Gbps; the backends get the value times 1000 * 1000.
	# parse_int() exits with status 1 on a bad value.
	requested = [parse_int(r, 0, 1000000, "ratelimit") * 1000 * 1000
		     for r in fields]
	try:
		maxrate.set(requested)
	except Exception:
		print("Rate limit is not supported on your system!")

try:
	ratelimit = maxrate.get()
except Exception:
	print("Rate limit is not supported on your system!")
	# Nothing could be read: never show the requested values as if they
	# were the ones in effect.
	ratelimit = []

# Read the current ETS and PFC settings.  The interface is first switched
# to host-controlled IEEE DCBX mode if those two bits are not both set.
try:
	cur_mode = ctrl.get_dcbx()
	expect_mode = DCB_CAP_DCBX_VER_IEEE | DCB_CAP_DCBX_HOST
	if (cur_mode & expect_mode) != expect_mode:
		ctrl.set_dcbx(cur_mode | expect_mode)

	prio_tc, tsa, tc_bw = ctrl.get_ieee_ets()

	pfc_en = ctrl.get_ieee_pfc()

except Exception:
	# ``Exception`` rather than a bare except, so that Ctrl-C is not
	# reported as a missing feature.  The defaults set earlier stay in use.
	print("ETS features are not supported on your system")

# -s/--tsa: override the algorithm of the first len(LIST) TCs; the
# remaining TCs keep their current algorithm.
if options.tsa:
	for i, t in enumerate(options.tsa.split(",")):
		if i >= 8:
			print("Too many items for TSA")
			sys.exit(1)

		if t == "strict":
			tsa[i] = IEEE_8021QAZ_TSA_STRICT
		elif t == "ets":
			tsa[i] = IEEE_8021QAZ_TSA_ETS
		else:
			# format the message; print((a, b)) would show a tuple repr
			print("Bad TSA value: %s" % t)
			parser.print_usage()
			sys.exit(1)

# -a: pretty_print() reads this module-level flag.
printall = bool(options.printall)

# -f/--pfc: rebuild the PFC bitmap from scratch (priorities left out of
# LIST end up disabled) and apply it.
if options.pfc:
	pfc_en = 0

	for i, t in enumerate(options.pfc.split(",")):
		if i >= 8:
			print("Too many items for PFC")
			sys.exit(1)

		# bit i of the bitmap belongs to priority i
		pfc_en |= parse_int(t, 0, 1, "PFC") << i

	try:
		ctrl.set_ieee_pfc(_pfc_en = pfc_en)
	except OSError as e:
		print(e)
		sys.exit(1)

# -t/--tcbw: override the bandwidth percent of the first len(LIST) TCs.
if options.tc_bw:
	for i, t in enumerate(options.tc_bw.split(",")):
		if i >= 8:
			print("Too many items for ETS BW")
			sys.exit(1)

		percent = parse_int(t, 0, 100, "ETS BW")

		# a strict TC takes no bandwidth share
		if percent != 0 and tsa[i] == IEEE_8021QAZ_TSA_STRICT:
			print("ETS BW for a strict TC must be 0")
			parser.print_usage()
			sys.exit(1)

		tc_bw[i] = percent

# -p/--prio_tc: override the TC of the first len(LIST) priorities.
if options.prio_tc:
	for i, t in enumerate(options.prio_tc.split(",")):
		if i >= 8:
			print("Too many items in UP => TC mapping")
			sys.exit(1)

		prio_tc[i] = parse_int(t, 0, 7, "UP => TC mapping")

# Push the ETS settings only when one of -s, -t or -p changed them.
ets_changed = options.tsa or options.tc_bw or options.prio_tc
if ets_changed:
	try:
		ctrl.set_ieee_ets(_prio_tc = prio_tc, _tsa = tsa, _tc_bw = tc_bw)
	except OSError as e:
		print(e)
		sys.exit(1)

# Finally show the resulting configuration.
pretty_print(prio_tc, tsa, tc_bw, ratelimit, pfc_en, trust)