# created by Nina Marthe 2023 - nina.marthe@ird.fr
# licensed under MIT

import subprocess
from .intersect import run_intersect,load_intersect,Features
from .argparser import read_args,arg
from .load_gfa import *
#from inference import *
from .graph_annot import graph_gff,graph_gaf
from .setup_transfer import transfer_on_target
from pathlib import Path

def main():
    """Command-line entry point of the transfer pipeline.

    Steps performed here:
      1. parse and check the command-line arguments (``arg`` / ``read_args``);
      2. intersect the source GFF with the segment-coordinate BED files, load
         the result, then delete the temporary ``intersect`` file;
      3. if requested, write the graph GFF and/or GAF into ``args.outdir``;
      4. if annotation, variation or alignment output is requested, run the
         transfer for every target genome (one sub-directory per genome) and,
         if requested, write the gene presence-absence (PAV) matrix.
    """
    args=arg()
    read_args(args)

    # intersect between gff and all the bed files from the source genome given by seg_coord
    run_intersect(args)

    intersect=args.outdir.joinpath("intersect").resolve() # intersect is in the output directory
    gfa=args.graph
    load_intersect(intersect.as_posix(),args.verbose)

    # Delete the intersect file now that it is loaded. pathlib is used instead of
    # running `rm` through a shell, which broke on paths containing spaces or shell
    # metacharacters. Like the former `rm` call, a failure is reported but not fatal.
    try:
        intersect.unlink()
    except OSError as err:
        print(f"     Could not delete {intersect.as_posix()}: {err}")

    segments=args.segment_coordinates_path.joinpath("segments.txt")

    # outputs the gff and gaf of the graph
    if args.graph_gff or args.graph_gaf:
        print("\n")
        if args.graph_gff:
            out_graph_gff=args.outdir.joinpath(gfa.stem).as_posix()+".gff" # TODO: what if there are several suffixes and the last one isn't '.gfa'?
            graph_gff(out_graph_gff,args.verbose)
        if args.graph_gaf:
            seg_size=get_segments_length(segments,False)
            out_graph_gaf=args.outdir.joinpath(gfa.stem).as_posix()+".gaf"
            graph_gaf(out_graph_gaf,seg_size,args.verbose)


    # TODO: what about pav ??
    if args.annotation or args.variation or args.alignment:
        if args.verbose:
            print('\n')

        get_target_genomes(args)

        # initialize the pav matrix (empty dict when no matrix is requested)
        pav_dict=init_pav(args,Features)

        # build a dictionary with the segment sizes to compute the coverage and id;
        # when the graph GAF was requested it has already been computed above
        if not args.graph_gaf:
            seg_size=get_segments_length(segments,args.verbose)

        # genome_index is the column of this genome in the PAV matrix: init_pav
        # builds its columns in args.target order, the same order iterated here
        for genome_index,target_genome in enumerate(args.target):
            print(f'\n{target_genome} transfer :')
            # create directory to store output files
            genome_dir=args.outdir.joinpath(target_genome).resolve()
            genome_dir.mkdir(exist_ok=True)

            # get list of files in seg_coord
            segment_coord_files=list(args.segment_coordinates_path.glob(f"*{target_genome}*.bed"))

            # create dictionaries with paths and segments positions.
            print(f'     Loading the walks for the genome {target_genome}')
            walks_path=args.segment_coordinates_path.joinpath("walks.txt").as_posix()
            target_genome_paths=get_paths(walks_path,target_genome,args.haplotype)
            if not args.verbose:
                print("     Loading the segments coordinates")
            segments_on_target_genome={}
            for file in segment_coord_files:
                # the glob above is a substring match, so check that the file really
                # belongs to this genome before loading it
                genome_name=get_genome_name(args.target,file.name,args.haplotype)
                if genome_name==target_genome :
                    if args.verbose:
                        print(f'     Loading the segments coordinates for the path {file.stem}')
                    get_segments_positions_on_genome(file.as_posix(),segments_on_target_genome)

            list_feat_absent=[]
            # do the annotation transfer (or var/aln); transfer_on_target is passed
            # list_feat_absent and its contents are read back below
            transfer_on_target(segments,genome_dir,target_genome,target_genome_paths,list_feat_absent,seg_size,args,segments_on_target_genome)
            # if pav matrix is asked, mark the genes absent from this genome
            if args.pav_matrix:
                for feat in list_feat_absent:
                    if Features[feat].type=="gene":
                        pav_dict[feat][genome_index]=0

        print_pav_matrix(pav_dict,args)


def init_pav(args,Features):
    """Create the empty presence-absence (PAV) matrix, or ``{}`` if none is requested.

    When ``args.pav_matrix`` is set, the returned dict holds:
      * a header entry ``"gene_id"`` listing the target genomes in
        ``args.target`` order;
      * one entry per feature of ``Features`` whose ``type`` is ``'gene'``,
        mapped to its own list of ``1`` values, one per target genome.
    Each gene gets a separate list so that it can be updated on its own.
    """
    if not args.pav_matrix:
        return {}
    genome_count=len(args.target)
    genes={
        feat_id:[1]*genome_count
        for feat_id,feature in Features.items()
        if feature.type=='gene'
    }
    return {"gene_id":list(args.target),**genes}

def print_pav_matrix(pav_dict,args):
    """Write the presence-absence (PAV) matrix to ``<args.outdir>/PAV_matrix.txt``.

    Does nothing unless ``args.pav_matrix`` is set. Each ``pav_dict`` entry
    becomes one tab-separated line: the key followed by its values converted
    with ``str``. Lines follow the dict's iteration order, so the
    ``"gene_id"`` header row built by ``init_pav`` comes first.

    Args:
        pav_dict: mapping of row name (str) to the list of values for that row.
        args: parsed arguments; ``pav_matrix`` and ``outdir`` are read.
    """
    if not args.pav_matrix:
        return
    print('\nGeneration of the presence-absence matrix for the transfered genes')
    # one join per row instead of repeated string += (quadratic in the worst case)
    rows=["\t".join([line,*map(str,fields)])+"\n" for line,fields in pav_dict.items()]

    out_pav=args.outdir.joinpath("PAV_matrix.txt")
    # explicit encoding so the output does not depend on the platform locale
    with open(out_pav,'w',encoding='utf-8') as file_out_pav:
        file_out_pav.writelines(rows)