import random
import math
import numpy as np
 

#def generate_chr_for_one_ind (chromosome_size, marker_density, err_rate, mean_depth, max_depth, markers_positions, conversion_factor):
def _kosambi_recombination_fraction(interval_bp, conversion_factor):
    """Return the Kosambi recombination fraction for a physical interval.

    Args:
        interval_bp (float): Physical distance between two markers, in bp.
        conversion_factor (float): Number of bp per cM.

    Returns:
        float: r = 0.5 * tanh(2 * d), with d the interval expressed in Morgans.
    """
    # bp -> cM (divide by bp-per-cM) -> Morgans (divide by 100); the factor 2 is Kosambi's.
    two_d = 2 * (interval_bp / 100) / conversion_factor
    # tanh(x) == (e^x - e^-x) / (e^x + e^-x), but it saturates at 1 instead of
    # overflowing math.exp when the interval is very large.
    return 0.5 * math.tanh(two_d)


def _zygosity(hap1, hap2):
    """Collapse two haplotype states ("A"/"B") into "A", "B" (homozygous) or "H"."""
    return hap1 if hap1 == hap2 else "H"


def _draw_homozygous_reads(site_depth, err):
    """Sample site_depth reads at a homozygous site, each showing the other allele with probability err.

    Returns:
        tuple: (number of reads showing the true allele, number showing the wrong allele).
    """
    # One random.random() draw per read.
    wrong = sum(1 for _ in range(site_depth) if random.random() < err)
    return site_depth - wrong, wrong


def _call_genotype(ref_count, alt_count):
    """Return the VCF GT implied by allele read counts at a covered site.

    Only one allele observed gives a homozygous call for that allele; both give "0/1".
    """
    if ref_count == 0:
        return "1/1"
    if alt_count == 0:
        return "0/0"
    return "0/1"


def _format_sample(gt, depth, ref_count, alt_count):
    """Format one sample column as GT:DP:AD followed by five empty ('.') fields.

    The layout matches the GT:DP:AD:RO:QR:AO:QA:GL FORMAT written by this script.
    """
    return "{}:{}:{},{}:.:.:.:.:.".format(gt, depth, ref_count, alt_count)


def generate_chr_for_one_ind(mean_depth, markers_positions, conversion_factor, errA, errB):
    """Simulate one individual's genotypes at every marker of a chromosome.

    The individual carries two haplotypes. Their states ("A" or "B") are drawn at random
    at the first marker. Each one then switches independently between consecutive markers
    with the Kosambi recombination fraction of the interval.

    At each marker a read depth is drawn from a Poisson distribution with mean
    ``mean_depth``, and the reads are split between the two alleles twice: once
    error-free, and once with per-read error rates applied at homozygous sites.

    Args:
        mean_depth (float): Mean of the Poisson distribution used for the per-site read depth.
        markers_positions (sequence of int): Marker positions in bp; expected to be
            sorted ascending (a negative interval yields a negative fraction, so no recombination).
        conversion_factor (float): bp per cM (chromosome size in bp / chromosome size in cM).
        errA (float): Probability that a read at a homozygous A site shows allele B.
        errB (float): Probability that a read at a homozygous B site shows allele A.

    Returns:
        tuple: (segment, segment_error, breakpoints)
            segment: one "GT:DP:AD:.:.:.:.:." string per marker, without read errors.
            segment_error: the same strings, with read errors applied at homozygous sites.
            breakpoints: [marker_index, "X => Y"] for every marker where at least one
                haplotype switched; X and Y are the "A"/"B"/"H" zygosity classes before
                and after the switch.
    """
    segment = []
    segment_error = []
    breakpoints = []

    genotype1 = genotype2 = None
    previous_zygosity = ""
    for i in range(len(markers_positions)):
        if i == 0:
            # First marker: draw both haplotype states independently with probability 1/2.
            genotype1 = "A" if random.random() > 0.5 else "B"
            genotype2 = "B" if random.random() > 0.5 else "A"
            previous_zygosity = _zygosity(genotype1, genotype2)
        else:
            rf = _kosambi_recombination_fraction(
                markers_positions[i] - markers_positions[i - 1], conversion_factor)
            # Independent draws, haplotype 1 first, then haplotype 2.
            recomb1 = random.random() > 1 - rf
            if recomb1:
                genotype1 = "B" if genotype1 == "A" else "A"
            recomb2 = random.random() > 1 - rf
            if recomb2:
                genotype2 = "B" if genotype2 == "A" else "A"
            if recomb1 or recomb2:
                current_zygosity = _zygosity(genotype1, genotype2)
                breakpoints.append([i, previous_zygosity + " => " + current_zygosity])
                previous_zygosity = current_zygosity

        # Poisson draws are never negative, so no clamping is needed.
        site_depth = int(np.random.poisson(mean_depth))
        zygosity = _zygosity(genotype1, genotype2)

        if site_depth == 0:
            # No reads: missing data in both outputs.
            gt, ref, alt = "./.", 0, 0
            gt_err, ref_err, alt_err = gt, ref, alt
        elif zygosity == "A":
            # Homozygous ref (0/0): errors turn some reads into alt reads.
            gt, ref, alt = "0/0", site_depth, 0
            ref_err, alt_err = _draw_homozygous_reads(site_depth, errA)
            gt_err = _call_genotype(ref_err, alt_err)
        elif zygosity == "B":
            # Homozygous alt (1/1): errors turn some reads into ref reads.
            gt, ref, alt = "1/1", 0, site_depth
            alt_err, ref_err = _draw_homozygous_reads(site_depth, errB)
            gt_err = _call_genotype(ref_err, alt_err)
        else:
            # Heterozygous: the ref read count is uniform on [0, site_depth], so a
            # low-depth het can be observed (and called) as homozygous.
            # NOTE(review): a uniform split differs from the binomial(site_depth, 0.5)
            # usually assumed for read sampling once depth >= 2; confirm intent.
            ref = random.randint(0, site_depth)
            alt = site_depth - ref
            gt = _call_genotype(ref, alt)
            # Read errors are not applied here, on the assumption that A->B and B->A
            # errors compensate at heterozygous sites.
            # NOTE(review): this is an approximation, not exact; confirm it's acceptable.
            gt_err, ref_err, alt_err = gt, ref, alt

        segment.append(_format_sample(gt, site_depth, ref, alt))
        segment_error.append(_format_sample(gt_err, site_depth, ref_err, alt_err))

    return segment, segment_error, breakpoints

def generate_individuals(nb_individuals, chromosome_size, marker_density, mean_depth, conversion_factor, errA, errB, breakpoints_path='Breakpoints_3x.csv'):
    """Simulate a population of individuals on one chromosome and log their breakpoints.

    Markers are placed on an evenly spaced integer grid from 1 to ``chromosome_size``.
    There are ``round(chromosome_size * marker_density)`` markers, shared by all
    individuals. Each individual is simulated with ``generate_chr_for_one_ind``.
    Every breakpoint is written as one CSV row to ``breakpoints_path``.

    Args:
        nb_individuals (int): Number of individuals to simulate.
        chromosome_size (int): Chromosome length in bp.
        marker_density (float): Markers per bp.
        mean_depth (float): Mean per-site read depth passed to the per-individual simulation.
        conversion_factor (float): bp per cM.
        errA (float): Per-read error rate at homozygous A sites.
        errB (float): Per-read error rate at homozygous B sites.
        breakpoints_path (str): Output CSV path for breakpoints (overwritten).

    Returns:
        tuple: (matrix, matrix_error, markers_positions)
            matrix: one list of per-marker sample strings per individual (no read errors).
            matrix_error: same, with read errors.
            markers_positions: the NumPy integer array of marker positions.
    """
    matrix = []
    matrix_error = []

    markers_nb = round(chromosome_size * marker_density)
    # Evenly spaced grid (the earlier random-sampling alternative was dropped).
    markers_positions = np.linspace(1, chromosome_size, markers_nb, dtype="int")

    with open(breakpoints_path, 'w') as breakpoint_file:
        breakpoint_file.write(",".join(["sample", "average_bkp_position", "bkp_start_position", "bkp_stop_position", "transitionType"]) + "\n")
        for i in range(nb_individuals):
            print("individual " + str(i + 1))
            segment, segment_error, breakpoints = generate_chr_for_one_ind(mean_depth, markers_positions, conversion_factor, errA, errB)
            matrix.append(segment)
            matrix_error.append(segment_error)
            for marker_index, transition in breakpoints:
                # A breakpoint lies between the previous marker and the one where the switch
                # was observed (marker_index >= 1 always). Convert to Python ints so the
                # midpoint is written as an integer whatever the NumPy version.
                start = int(markers_positions[marker_index - 1])
                stop = int(markers_positions[marker_index])
                average = start + int(round((stop - start) / 2))
                bkp_row = [str(i), str(average), str(start), str(stop), transition]
                breakpoint_file.write(",".join(bkp_row) + "\n")

    return matrix, matrix_error, markers_positions

# # Check if a row is correct
# def is_correct (row):
#     l = len(row)

#     count_h = row.count("H") / l
#     count_a = row.count("A") / l
#     count_b = row.count("B") / l
    
#     return [count_a, count_b, count_h, 0.45 <= count_h <= 0.55 and 0.2 <= count_b <= 0.3 and 0.2 <= count_a <= 0.3]

# ---------------------------------------------------------------------------
# Simulation parameters
# ---------------------------------------------------------------------------
nb_individuals = 100           # individuals simulated (one VCF sample column each)
chromosome_size = 44_000_000   # physical chromosome length, in bp
cMsize = 180                   # genetic length of the chromosome, in cM
# Physical-to-genetic conversion, in bp per cM (about 244,444 bp/cM here).
# NOTE(review): an earlier comment said this "does not produce the correct division";
# `/` is true division in Python 3, so confirm what was meant before changing it.
conversion_factor = chromosome_size / cMsize
marker_density = 0.0055        # markers per bp; round(chromosome_size * marker_density) markers
mean_depth = 1.5               # Poisson mean of the per-site read depth
max_depth = 3                  # not used below; TODO better estimate of the maximum depth
errA = 0.005                   # per-read error rate at homozygous A sites
errB = 0.005                   # per-read error rate at homozygous B sites

# Simulate the population; `pop` keeps the packed result, and the three parts
# (error-free matrix, matrix with errors, marker positions) are unpacked below.
pop = generate_individuals(nb_individuals, chromosome_size, marker_density, mean_depth, conversion_factor, errA, errB)
matrix, matrix_error, markers_positions = pop

# VCF meta-information lines written at the top of every output file.
# NOTE(review): fileDate/source appear to be placeholder values (they match the example
# in the VCF specification), and reference is a fixed cluster path; update if shared.
header = [
    "##fileformat=VCFv4.1\n",
    "##fileDate=20090805\n",
    "##source=myImputationProgramV3.1\n",
    "##reference=/shared/projects/recombinationlandscape/REFERENCE_GENOME/Osat_Azucena_AGI_chrOK_uniline_WithNIPBARorganelles.fasta\n",
    "##contig=<ID=chr01,length=44011168>\n",
    "##phasing=partial\n",
    "##INFO=<ID=NS,Number=1,Type=Integer,Description=\"Number of Samples With Data\">\n",
    "##INFO=<ID=DP,Number=1,Type=Integer,Description=\"Total Depth\">\n",
    "##INFO=<ID=AF,Number=A,Type=Float,Description=\"Allele Frequency\">\n",
    "##INFO=<ID=AA,Number=1,Type=String,Description=\"Ancestral Allele\">\n",
    "##INFO=<ID=DB,Number=0,Type=Flag,Description=\"dbSNP membership, build 129\">\n",
    "##INFO=<ID=H2,Number=0,Type=Flag,Description=\"HapMap2 membership\">\n",
    "##FILTER=<ID=q10,Description=\"Quality below 10\">\n",
    "##FILTER=<ID=s50,Description=\"Less than 50% of samples have data\">\n",
    "##FORMAT=<ID=GT,Number=1,Type=String,Description=\"Genotype\">\n",
    "##FORMAT=<ID=GQ,Number=1,Type=Integer,Description=\"Genotype Quality\">\n",
    "##FORMAT=<ID=DP,Number=1,Type=Integer,Description=\"Read Depth\">\n",
    "##FORMAT=<ID=HQ,Number=2,Type=Integer,Description=\"Haplotype Quality\">\n",
    # Every key of the FORMAT column the rows use (GT:DP:AD:RO:QR:AO:QA:GL) must be
    # declared; undeclared keys make parsers warn and fall back to Type=String.
    "##FORMAT=<ID=AD,Number=.,Type=Integer,Description=\"Allelic depths for the ref and alt alleles in the order listed\">\n",
    "##FORMAT=<ID=RO,Number=1,Type=Integer,Description=\"Reference allele observation count\">\n",
    "##FORMAT=<ID=QR,Number=1,Type=Integer,Description=\"Sum of quality of the reference observations\">\n",
    "##FORMAT=<ID=AO,Number=A,Type=Integer,Description=\"Alternate allele observation count\">\n",
    "##FORMAT=<ID=QA,Number=A,Type=Integer,Description=\"Sum of quality of the alternate observations\">\n",
    "##FORMAT=<ID=GL,Number=G,Type=Float,Description=\"Genotype likelihoods (log10-scaled)\">\n",
]


## Write the VCF files (with and without errors)
def _write_vcf(path, genotype_matrix):
    """Write one simulated VCF file to ``path`` (overwritten).

    Writes the module-level ``header`` lines and the column header line. Then writes
    one row per marker in ``markers_positions``: one column per individual taken from
    ``genotype_matrix`` (a list of per-marker sample strings per individual), followed
    by two fixed parental columns, homozygous ref and homozygous alt, each at depth
    round(mean_depth).
    """
    parent_depth = str(round(mean_depth))
    # Give every FORMAT key a value so the parental columns have the same 8 fields
    # (GT:DP:AD:RO:QR:AO:QA:GL) as the simulated samples.
    parent_ref = "0/0:" + parent_depth + ":" + parent_depth + ",0:.:.:.:.:."
    parent_alt = "1/1:" + parent_depth + ":0," + parent_depth + ":.:.:.:.:."
    # Name sample columns after the rows actually written, so header and data stay aligned.
    column_names = (["#CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER", "INFO", "FORMAT"]
                    + [str(n) for n in range(len(genotype_matrix))]
                    + ["Parent1", "Parent2"])
    with open(path, 'w') as vcf_file:
        vcf_file.writelines(header)
        vcf_file.write("\t".join(column_names) + "\n")
        # Iterate over the markers (not matrix[0]) so an empty population still writes rows.
        for m, position in enumerate(markers_positions):
            row = ["chr01", str(position), ".", "A", "T", ".", ".", ".", "GT:DP:AD:RO:QR:AO:QA:GL"]
            row.extend(individual[m] for individual in genotype_matrix)
            row.append(parent_ref)
            row.append(parent_alt)
            vcf_file.write("\t".join(row) + "\n")


_write_vcf('test_3x.vcf', matrix)
_write_vcf('test_error_3x.vcf', matrix_error)