#!/usr/bin/python
#
# Creates a CSV report from a BAM file.
# This one is modified to include the full taxonomy down to the target tax level.
#
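# Output columns: one per major rank from Domain down to the target level
# (Domain,Kingdom,Phylum,Class,Order,Family,Genus,Species), followed by
# taxid_at_target_level,read_count,read_percent,sample_id
# e.g. for -t genus: Domain,Kingdom,Phylum,Class,Order,Family,Genus,taxid_at_target_level,read_count,read_percent,sample_id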
#
# Porter L, with optimizations from Gemini
#
#  Env is set up on Lovelace
#
# Example:
#   module load samtools 
#   module load python/3.12
#   python bam_csv_fulltax.py -i /path/to/sample.bam -t genus

import os, sys, re, io, argparse, time, csv, psycopg2, pysam, json, subprocess
from tqdm import tqdm

# Number of alignment records whose reference names are collected before one
# batched database lookup (see process_bam_optimized / process_batch).
batch_size = 5000

# Rank names ordered from broadest to most specific. Used to validate the -t
# argument and to decide which major ranks lie at or above the target level
# (by comparing positions in this list). Note that 'Unknown' is capitalised,
# so it never matches the lower-cased -t value.
taxonomy_order = [
  'Unknown', 'no rank', 'clade', 'cellular root',
  'acellular root', 'realm', 'domain', 'kingdom',
  'subkingdom', 'superphylum', 'phylum', 'subphylum',
  'superclass', 'class', 'subclass', 'infraclass',
  'cohort', 'superorder', 'order', 'suborder', 'infraorder',
  'parvorder', 'superfamily', 'family', 'subfamily',
  'tribe', 'subtribe', 'genus', 'subgenus', 'section',
  'subsection', 'series', 'species group',
  'species subgroup', 'species', 'subspecies', 'varietas',
  'subvariety', 'forma', 'forma specialis', 'isolate',
  'strain', 'biotype', 'genotype', 'morph',
  'serogroup', 'serotype', 'pathogroup'
]

# Ranks reported as CSV name columns, in this order, truncated at the target
# level. Also the only ranks kept in each taxon's stored lineage.
MAJOR_RANKS = [
    'domain', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'
]

def parse_args():
    """Read the command line.

    Options:
        -i/--input      path to the BAM/SAM file to report on (required).
        -t/--tax_level  rank to aggregate reads at, e.g. 'genus' (required).

    Returns:
        argparse.Namespace with ``input`` and ``tax_level`` attributes.
    """
    cli = argparse.ArgumentParser(description="Report on BAM Contents for Control purposes")
    cli.add_argument('-i', '--input', required=True, help='Input Bam/Sam file')
    cli.add_argument('-t', '--tax_level', required=True, help="'family', 'genus', 'species', etc.")
    return cli.parse_args()

# Module-level database connection. The __main__ block passes it to
# process_bam_optimized, and write_csv reads this global directly.
# The host MUST be a string literal: an unquoted hostname is evaluated as a
# Python expression, raises NameError, and silently leaves conn = None.
# Each parameter can be overridden through an environment variable, so
# credentials need not live in source control. The defaults keep the values
# that were previously hard-coded.
try:
    conn = psycopg2.connect(
        host=os.environ.get("NCBI_DB_HOST", "lovelace.cluster.earlham.edu"),
        dbname=os.environ.get("NCBI_DB_NAME", "ncbi"),
        user=os.environ.get("NCBI_DB_USER", "fieldsci"),
        password=os.environ.get("NCBI_DB_PASSWORD", "skalanes"),
    )
except Exception as e:
    # Best-effort: without a connection, every accession lookup comes back
    # empty and reads are reported as "Unknown".
    print(f"Could not connect to database: {e}. Check connection parameters.")
    conn = None # Set to None if connection fails


def get_taxids_for_accessions(accessions, conn):
    """Map reference accessions to NCBI tax IDs with one batched query.

    Args:
        accessions: iterable of ``accession_version`` strings. Duplicates and
            ``None`` entries (e.g. the reference name of an unmapped read)
            are ignored.
        conn: open DB connection, or ``None``.

    Returns:
        dict of accession -> tax_id for every accession found in
        ``accession_taxid``. Accessions that are not found are absent.
        Returns ``{}`` when there is nothing to look up, no connection, or
        the query fails.
    """
    accession_map = {}
    if not accessions or not conn:
        return accession_map

    unique_accessions = list({acc for acc in accessions if acc is not None})
    if not unique_accessions:
        return accession_map

    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT accession_version, tax_id
                FROM accession_taxid
                WHERE accession_version = ANY(%s);
            """, (unique_accessions,))
            for accession, tax_id in cur.fetchall():
                accession_map[accession] = tax_id
    except Exception as e:
        print(f"[Error] Batch query failed: {e}")
        # A failed statement leaves the transaction aborted, and every later
        # query on this connection would be rejected. Roll back so subsequent
        # batches can still be resolved.
        try:
            conn.rollback()
        except Exception as rollback_error:
            print(f"[Error] Rollback failed: {rollback_error}")
    return accession_map

def get_lineage_from_taxid(taxid, target_tax_level, conn):
    """Walk the ``nodes`` table from ``taxid`` toward the root and summarise it.

    One query is issued per ancestor. The walk stops in any of these cases:
    the next tax_id is falsy or is 1 (the root); a tax_id has no ``nodes``
    row; a tax_id was already visited (guards against a parent cycle looping
    forever); or a query fails.

    Args:
        taxid: tax_id to start from.
        target_tax_level: rank to report at, e.g. 'genus' (case-insensitive).
        conn: open DB connection, or None.

    Returns:
        ``(target_taxid, lineage)``.

        ``lineage`` maps each rank in MAJOR_RANKS found on the walk to its
        tax_id.

        ``target_taxid`` is the tax_id of the node at ``target_tax_level``,
        or "Other" if the walk never met that rank.

        Without a connection the function returns ``(None, {})``.
    """
    if not conn:
        return None, {}

    target_rank = target_tax_level.lower()
    full_lineage = {}  # rank -> tax_id of the first node seen with that rank
    target_taxid = None
    visited = set()
    current_taxid = taxid

    while current_taxid and current_taxid != 1 and current_taxid not in visited:
        visited.add(current_taxid)
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT tax_id, parent_tax_id, rank
                    FROM nodes
                    WHERE tax_id = %s;
                """, (current_taxid,))
                info = cur.fetchone()
        except Exception as e:
            print(f"[Error] Lineage query failed for tax_id {current_taxid}: {e}")
            # A failed statement aborts the transaction. Roll back so later
            # queries on this connection are not rejected too.
            try:
                conn.rollback()
            except Exception as rollback_error:
                print(f"[Error] Rollback failed: {rollback_error}")
            break

        if not info:
            break
        tax_id, parent_tax_id, rank = info
        # Tolerate a NULL rank instead of crashing on None.lower().
        rank = (rank or "").lower()
        # The first occurrence (nearest the start) wins for both the lineage
        # and the target. A repeated rank such as 'clade' therefore resolves
        # to the same node in both.
        if rank and rank not in full_lineage:
            full_lineage[rank] = tax_id
            if rank == target_rank:
                target_taxid = tax_id
        current_taxid = parent_tax_id

    # Older NCBI taxonomy dumps label the top cellular rank 'superkingdom'
    # rather than 'domain'. Use it for the Domain column when no 'domain'
    # node was seen.
    if 'domain' not in full_lineage and 'superkingdom' in full_lineage:
        full_lineage['domain'] = full_lineage['superkingdom']

    lineage = {rank: full_lineage[rank] for rank in MAJOR_RANKS if rank in full_lineage}
    return (target_taxid if target_taxid is not None else "Other"), lineage


def tax_lump_batch_modified(taxids, target_tax_level, conn):
    """Resolve each distinct tax_id to its target-level taxon and lineage.

    Args:
        taxids: iterable of tax IDs (duplicates are collapsed).
        target_tax_level: rank name passed through to get_lineage_from_taxid.
        conn: open DB connection, or None.

    Returns:
        dict of tax_id -> ``{'lineage': {rank: tax_id}, 'target_tax_id': ...}``.
        When get_lineage_from_taxid yields a falsy target (it returns None
        when there is no connection), the entry is "Other" with an empty
        lineage.
    """
    lumped = {}
    for tid in set(taxids):
        target, lineage = get_lineage_from_taxid(tid, target_tax_level, conn)
        if not target:
            lumped[tid] = {'lineage': {}, 'target_tax_id': "Other"}
            continue
        lumped[tid] = {'lineage': lineage, 'target_tax_id': target}
    return lumped

def process_bam_optimized(input_file, tax_level, conn):
    """Count alignment records per taxon at ``tax_level``.

    Records are streamed from the file. Their reference names (accessions)
    are handed to process_batch in groups of ``batch_size``, and process_batch
    fills the caches and accumulators below.

    Args:
        input_file: path to the alignment file (opened with pysam in "rb" mode).
        tax_level: rank name; must appear in taxonomy_order.
        conn: open DB connection, or None.

    Returns:
        ``(counts_data, total_reads, lineage_lookup)``.

        ``counts_data`` maps each target tax_id (or "Unknown"/"Other") to a
        read count. ``lineage_lookup`` maps the same keys to
        ``{rank: tax_id}``.

        On an invalid level or a read error the result is ``({}, 0, {})``,
        so callers can always unpack three values.
    """
    if tax_level not in taxonomy_order:
        print(f"[Error] Invalid taxonomic level: {tax_level}")
        return {}, 0, {}

    total_reads = 0
    taxid_cache = {}           # accession -> tax_id
    lumped_lineage_cache = {}  # tax_id -> {'lineage': ..., 'target_tax_id': ...}
    accession_batch = []
    counts_data = {}           # target tax_id (or "Unknown"/"Other") -> read count
    lineage_lookup = {}        # target tax_id (or "Unknown"/"Other") -> {rank: tax_id}

    try:
        with pysam.AlignmentFile(input_file, "rb") as bamfile:
            # NOTE(review): every record is counted, including secondary and
            # supplementary alignments if the file contains them. Filter on
            # read.is_secondary / read.is_supplementary if one-per-read counts
            # are intended. Records without a reference name end up under
            # "Unknown".
            for read in tqdm(bamfile.fetch(until_eof=True), unit="read", unit_scale=True, desc="Parsing BAM file"):
                total_reads += 1
                accession_batch.append(read.reference_name)
                if len(accession_batch) >= batch_size:
                    process_batch(accession_batch, taxid_cache, lumped_lineage_cache, tax_level, conn, counts_data, lineage_lookup)
                    accession_batch = []
    except FileNotFoundError:
        print(f"[Error] The file {input_file} was not found.")
        return {}, 0, {}
    except Exception as e:
        print(f"[Error] An unexpected error occurred: {e}")
        return {}, 0, {}

    # Flush the final, partially filled batch.
    if accession_batch:
        process_batch(accession_batch, taxid_cache, lumped_lineage_cache, tax_level, conn, counts_data, lineage_lookup)

    return counts_data, total_reads, lineage_lookup


def process_batch(accession_batch, taxid_cache, lumped_lineage_cache, tax_level, conn, counts_data, lineage_lookup):
    """Resolve one batch of accessions and tally reads per target taxon.

    All cache and accumulator arguments are updated in place:
      taxid_cache           accession -> tax_id, extended with new lookups.
      lumped_lineage_cache  tax_id -> {'lineage': ..., 'target_tax_id': ...}.
      counts_data           target tax_id (or "Unknown"/"Other") -> read count.
      lineage_lookup        target tax_id (or "Unknown"/"Other") -> {rank: ...}.

    An accession with no known tax_id is counted under "Unknown". The
    placeholder lineage stored for "Unknown"/"Other" covers the major ranks
    at or above ``tax_level``.
    """
    # Only hit the database for accessions never seen before.
    uncached = {acc for acc in accession_batch if acc not in taxid_cache}
    if uncached:
        taxid_cache.update(get_taxids_for_accessions(list(uncached), conn))

    # Likewise, only compute lineages for tax IDs not already resolved.
    batch_taxids = {taxid_cache.get(acc, "Unknown") for acc in accession_batch}
    batch_taxids.discard("Unknown")
    unlumped = [tid for tid in batch_taxids if tid not in lumped_lineage_cache]
    if unlumped:
        lumped_lineage_cache.update(tax_lump_batch_modified(unlumped, tax_level, conn))

    missing = {'lineage': {}, 'target_tax_id': "Unknown"}
    for acc in accession_batch:
        entry = lumped_lineage_cache.get(taxid_cache.get(acc, "Unknown"), missing)
        target = entry['target_tax_id']
        counts_data[target] = counts_data.get(target, 0) + 1
        if target == "Unknown" or target == "Other":
            if target not in lineage_lookup:
                depth = taxonomy_order.index(tax_level)
                lineage_lookup[target] = {
                    rank: target for rank in MAJOR_RANKS if taxonomy_order.index(rank) <= depth
                }
        else:
            lineage_lookup[target] = entry['lineage']


def get_latin_from_taxid(tax_id, conn):
    """Return the scientific name stored for a single tax_id.

    Returns "Unknown (No DB)" when ``conn`` is falsy, and "TaxID:<id>" when
    ``scientific_names`` has no row for ``tax_id``.
    """
    if conn:
        with conn.cursor() as cursor:
            cursor.execute("SELECT name_txt FROM scientific_names WHERE tax_id = %s;", (tax_id,))
            match = cursor.fetchone()
        return match[0] if match else f"TaxID:{tax_id}"
    return "Unknown (No DB)"


def get_latin_names_batch(tax_ids, conn):
    """Look up scientific names for many tax IDs with a single query.

    Args:
        tax_ids: iterable of tax IDs. Values are handled as follows:
            - ints and decimal-digit strings are looked up;
            - the placeholders "Unknown"/"Other" map to themselves;
            - any other value is left out of the result.
        conn: open DB connection, or None.

    Returns:
        dict keyed by the original ``tax_ids`` values. A numeric ID with no
        name maps to "TaxID:<id>". Returns ``{}`` when ``tax_ids`` is empty,
        there is no connection, or none of the IDs is numeric.
    """
    name_map = {}
    if not tax_ids or not conn:
        return name_map

    # Original value -> int tax_id. Sending a homogeneous int list keeps the
    # array parameter typed as integer[]. A list of digit strings would be
    # sent as a text array, which the integer tax_id column cannot be
    # compared with.
    lookup = {}
    for tid in tax_ids:
        if isinstance(tid, int) or str(tid).isdecimal():
            lookup[tid] = int(tid)

    if not lookup:
        return name_map

    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT tax_id, name_txt
                FROM scientific_names
                WHERE tax_id = ANY(%s);
            """, (sorted(set(lookup.values())),))
            for tax_id, name in cur.fetchall():
                name_map[tax_id] = name
    except Exception as e:
        print(f"[Error] Batch name query failed: {e}")
        # Clear the aborted transaction so later queries still work.
        try:
            conn.rollback()
        except Exception as rollback_error:
            print(f"[Error] Rollback failed: {rollback_error}")

    final_name_map = {}
    for tid in tax_ids:
        if tid in ("Unknown", "Other"):
            final_name_map[tid] = tid
        elif tid in lookup:
            final_name_map[tid] = name_map.get(lookup[tid], f"TaxID:{tid}")
    return final_name_map


def print_report(data, total_reads, input_file, output_file, exec_time, target_level):
    """Print a human-readable run summary to stdout.

    Args:
        data: target tax_id (or the "Unknown"/"Other" placeholders) -> read count.
        total_reads: number of alignment records parsed.
        input_file: input path echoed in the summary.
        output_file: CSV path echoed in the summary.
        exec_time: elapsed seconds.
        target_level: rank name the report was aggregated at.
    """
    print(f"Input File: {input_file}")
    print(f"Target Level: {target_level}")
    print(f"Output File: {output_file}")
    print(f"Processing Time: {exec_time} sec")
    print(f"  TOTAL READS: {total_reads}")

    unknown_reads = data.get("Unknown", 0)
    other_reads = data.get("Other", 0)
    classified_reads = total_reads - unknown_reads - other_reads
    # The "Unknown"/"Other" buckets are not taxa at the target level, so they
    # are excluded from the ID count.
    taxid_count = sum(1 for key in data if key not in ("Unknown", "Other"))

    print(f"  TOTAL {target_level.upper()} IDs: {taxid_count}")

    def percent(reads):
        # Avoid division by zero for an empty input.
        return round(reads / total_reads * 100, 2) if total_reads else 0

    print(f"  Unknown Reads: {unknown_reads} ({percent(unknown_reads)}%)")
    print(f"  Other Reads: {other_reads} ({percent(other_reads)}%)")
    print(f"  Classified Reads: {classified_reads} ({percent(classified_reads)}%)")


def write_csv(data, output_file, sample_id, total_reads, lineage_lookup, target_level):
    """Write per-taxon read counts to ``output_file`` as CSV.

    There is one row per key of ``data`` (a target tax_id or the
    "Unknown"/"Other" placeholders), sorted by read count, descending.

    Columns are the major ranks from Domain down to ``target_level``, with
    scientific names fetched through the module-level ``conn``. They are
    followed by taxid_at_target_level, read_count, read_percent and sample_id.

    Args:
        data: target tax_id -> read count.
        output_file: destination path (overwritten).
        sample_id: value written in the sample_id column.
        total_reads: denominator for read_percent.
        lineage_lookup: target tax_id -> {rank: tax_id or placeholder}.
        target_level: rank name; must appear in taxonomy_order.

    Raises:
        ValueError: if ``target_level`` is not in taxonomy_order.
    """
    target_level = target_level.lower()
    # Pick the name columns by position in taxonomy_order, the same rule
    # process_batch uses for the placeholder lineages. For major ranks this
    # matches a MAJOR_RANKS.index comparison. Unlike that comparison, it also
    # accepts non-major targets (e.g. 'subfamily') that process_bam_optimized
    # allows.
    target_pos = taxonomy_order.index(target_level)
    report_ranks = [rank for rank in MAJOR_RANKS if taxonomy_order.index(rank) <= target_pos]
    csv_header = [rank.capitalize() for rank in report_ranks] + [
        'taxid_at_target_level', 'read_count', 'read_percent', 'sample_id']

    placeholders = ("Unknown", "Other")
    ids_to_name = set()
    for target_taxid, lineage in lineage_lookup.items():
        if target_taxid in placeholders:
            continue
        ids_to_name.add(int(target_taxid))
        # Lineages may be incomplete or contain placeholders; skip those.
        ids_to_name.update(int(t) for t in lineage.values() if t not in placeholders)

    # Fetch all scientific names in one batch (uses the module-level connection).
    latin_name_map = get_latin_names_batch(list(ids_to_name), conn)

    # Scientific names can contain non-ASCII characters. Write UTF-8
    # explicitly rather than relying on the platform's default encoding.
    with open(output_file, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(csv_header)
        for target_taxid, count in sorted(data.items(), key=lambda item: item[1], reverse=True):
            fallback = "Other" if target_taxid == "Other" else "Unknown"
            lineage = lineage_lookup.get(target_taxid, {})
            row = []
            for rank in report_ranks:
                taxid = lineage.get(rank, fallback)
                row.append(taxid if taxid in placeholders else latin_name_map.get(taxid, taxid))
            read_percent = round(count / total_reads * 100, 4) if total_reads else 0
            row.extend([target_taxid, count, read_percent, sample_id])
            writer.writerow(row)

if __name__ == "__main__":
    start_time = time.time()

    try:
        args = parse_args()
        input_file = args.input
        # The sample ID is the file name up to its first '.',
        # e.g. /data/S1.sorted.bam -> S1.
        sample_id = os.path.basename(input_file).split(".")[0]
        target_level = args.tax_level.lower()
        output_file = f"{sample_id}-{target_level}-fulltax.csv"
        data, total_reads, lineage_lookup = process_bam_optimized(input_file, target_level, conn)

        if total_reads > 0:
            write_csv(data, output_file, sample_id, total_reads, lineage_lookup, target_level)
            exec_time = round(time.time() - start_time, 2)
            print("Done.\n\n")
            print_report(data, total_reads, input_file, output_file, exec_time, target_level)
        else:
            print(f"No reads found in {input_file} or processing failed.")
    finally:
        # Close the module-level connection even when argument parsing,
        # processing or CSV writing raises.
        if conn:
            conn.close()