#! /usr/bin/python3
import rados
import json
import argparse
import subprocess

# Command-line interface.  Note: --max-pgs is read later as args.max_pgs
# (argparse converts the hyphen); its default is effectively "no limit".
parser = argparse.ArgumentParser(prog="ct-get-degraded-rbd",
                                 description="Script to show degraded RBD")
parser.add_argument("-p",
                    "--pools",
                    action="store",
                    default="nova,cinder",
                    type=str,
                    help="Pools to analyse separate by comma "
                         "(default: nova,cinder)")
parser.add_argument("-s",
                    "--state",
                    action="store",
                    default="inactive",
                    type=str,
                    help="State of PGs to analyse (default: inactive)")
parser.add_argument("-j",
                    "--json",
                    action="store_true",
                    help="Output in json format")
parser.add_argument("-S",
                    "--servers",
                    action="store",
                    nargs='*',
                    help="List of server to analyse")
parser.add_argument("-m",
                    "--max-pgs",
                    action="store",
                    default=2**31,
                    type=int,
                    help="Maximum number of PGs to analyse per pool "
                         "(default: unlimited)")
args = parser.parse_args()

# Accumulators filled during the analysis below.
metadatas = set()   # RBD image names found via rbd_id.* objects
datas = set()       # RBD image names resolved from rbd_data.* objects
out_json = {}

# Connect to the cluster and resolve the numeric id of each requested pool.
requested = args.pools.split(',')
cluster = rados.Rados(conffile='/etc/ceph/ceph.conf')
cluster.connect()
cmd = json.dumps({'prefix': 'osd lspools', 'format': 'json'})
ret, stdout, _ = cluster.mon_command(cmd, b'')
pools = {entry['poolname']: entry['poolnum']
         for entry in json.loads(stdout)
         if entry['poolname'] in requested}

if args.servers:
    # Server mode: flag a PG when it lives on OSDs spread over more than
    # one of the listed servers (several replicas share the fault domain
    # under analysis).
    osds = []
    pg_on_osd = {}
    for srv in args.servers:
        cmd = json.dumps({'prefix': 'osd tree-from',
                          'bucket': srv,
                          'format': 'json'})
        ret, stdout, _ = cluster.mon_command(cmd, b'')
        tree = json.loads(stdout)
        # The first node is the bucket itself; its children are OSD ids.
        osds += tree['nodes'][0]['children']
    for osd in osds:
        pgs = json.loads(subprocess.getoutput(
                f"ceph pg ls-by-osd {osd} -f json"))
        print(f"\rAnalyze osd.{osd}   ", end='')
        # Deduplicate per OSD, then count on how many OSDs each PG appears.
        for pg in {x['pgid'] for x in pgs['pg_stats']}:
            pg_on_osd[pg] = pg_on_osd.get(pg, 0) + 1
    pges = [k for k, v in pg_on_osd.items() if v > 1]
else:
    # State mode: ask the monitors for PGs stuck in the requested state.
    state = args.state
    cmd = json.dumps({'prefix': "pg dump_stuck",
                      'states': [state],
                      'format': 'json',
                      '': 2})
    ret, stdout, _ = cluster.mon_command(cmd, b'')
    pgs = json.loads(stdout)
    if 'stuck_pg_stats' not in pgs:
        # Nothing stuck: exit cleanly.  SystemExit is raised directly
        # instead of quit(), which depends on the optional site module.
        raise SystemExit(0)
    pges = [x['pgid'] for x in pgs['stuck_pg_stats']]

print(f"\rTotal affected PG: {len(pges)}\n")

# Bucket affected PGs by pool: a pgid is "<poolnum>.<shard>".
pg_by_pool = {
    pool_name: [pg for pg in pges if pg.startswith(f"{pool_id}.")]
    for pool_name, pool_id in pools.items()
}

# Per-pool analysis: list the objects of every affected PG, collect the
# RBD image names/ids they belong to, then resolve data-object ids back
# to image names through the pool's rbd_directory omap.
# getattr keeps this working even when no --max-pgs option is declared.
max_pgs = getattr(args, 'max_pgs', None)
for pool in pools:
    if max_pgs is not None:
        # Bound the amount of work per pool when --max-pgs is given.
        del pg_by_pool[pool][max_pgs:]
    tot_pg = len(pg_by_pool[pool])
    print(f"Pool {pool} (Affected PG {tot_pg})")
    rbd_datas = set()
    cpt = 1
    for pg in pg_by_pool[pool]:
        print(f"\r\tAnalyze PG {pg} ({cpt}/{tot_pg})", end='')
        cpt += 1
        obs = json.loads(subprocess.getoutput(f"rados --pgid {pg} ls -f json"))
        for ob in obs:
            # RBD object names look like "rbd_id.<image-name>" (metadata)
            # or "rbd_data.<image-id>.<object-number>" (data).
            prefix, _, remainder = ob['name'].partition('.')
            if prefix == 'rbd_id':
                # Image names may themselves contain dots, so keep the
                # whole remainder instead of splitting on '.'.
                metadatas.add(remainder)
            elif prefix == 'rbd_data':
                # Image ids are hex strings and never contain dots.
                rbd_datas.add(remainder.split('.')[0])
    print("\n\tMapping RBD Data", end='')
    ioctx = cluster.open_ioctx(pool)
    for rbd_data in rbd_datas:
        print('.', end='')
        key = f"id_{rbd_data}"
        with rados.ReadOpCtx(ioctx) as read_op:
            try:
                vals, ret = ioctx.get_omap_vals_by_keys(read_op, (key,))
                ioctx.operate_read_op(read_op, "rbd_directory")
                # The omap value is a 4-byte length prefix + image name.
                name = list(vals)[0][1][4:].decode('utf8')
                datas.add(name)
            except Exception:
                # Best effort: skip ids no longer in the directory
                # (e.g. images deleted while we were scanning).
                pass
    ioctx.close()
    print("\n")
cluster.shutdown()

if args.json:
    # Emit real JSON (json.dumps) rather than the Python dict repr, so
    # the "-j" output can actually be parsed by jq & friends.
    out_json['affected metadata'] = list(metadatas)
    out_json['affected data'] = list(datas)
    print(json.dumps(out_json))
else:
    print("Affected metadata:")
    for md in metadatas:
        print(f"\t{md}")
    print("Affected data:")
    for d in datas:
        print(f"\t{d}")
