import xml.etree.ElementTree as ET
import cv2
import random

def parse_xml(file_path):
    """
    解析XML文件，提取公式框和文本框的坐标。
    
    参数:
    file_path (str): XML文件路径。

    返回:
    tuple: 公式框列表和文本框列表，格式为[(xmin, ymin, xmax, ymax), ...]。
    """
    tree = ET.parse(file_path)
    root = tree.getroot()

    formula_boxes = []
    text_boxes = []

    for obj in root.findall('object'):
        name = obj.find('name').text
        xmin = int(obj.find('bndbox/xmin').text)
        ymin = int(obj.find('bndbox/ymin').text)
        xmax = int(obj.find('bndbox/xmax').text)
        ymax = int(obj.find('bndbox/ymax').text)
        box = (xmin, ymin, xmax, ymax)

        if name == 'formula':
            formula_boxes.append(box)
        elif name == 'text':
            text_boxes.append(box)
    
    return formula_boxes, text_boxes

def find_closest_text_boxes(formula_box, text_boxes):
    """
    找到给定公式框最近的左、右、上、下文本框。

    参数:
    formula_box (tuple): 公式框的坐标(xmin, ymin, xmax, ymax)。
    text_boxes (list): 文本框列表，格式为[(xmin, ymin, xmax, ymax), ...]。

    返回:
    tuple: 最近的左、右、上、下文本框的坐标，如果没有则为None。
    """
    left_text = right_text = top_text = bottom_text = None
    left_dist = right_dist = top_dist = bottom_dist = float('inf')
    xmin_f, ymin_f, xmax_f, ymax_f = formula_box
        
    for text_box in text_boxes:
        xmin_t, ymin_t, xmax_t, ymax_t = text_box

        if check_overlap(formula_box, text_box):
            continue

        if xmax_t < xmin_f and xmin_f - xmax_t < left_dist:
            left_dist = xmin_f - xmax_t
            left_text = text_box
        if xmin_t > xmax_f and xmin_t - xmax_f < right_dist:
            right_dist = xmin_t - xmax_f
            right_text = text_box
        if ymax_t < ymin_f and ymin_f - ymax_t < top_dist:
            top_dist = ymin_f - ymax_t
            top_text = text_box
        if ymin_t > ymax_f and ymin_t - ymax_f < bottom_dist:
            bottom_dist = ymin_t - ymax_f
            bottom_text = text_box

    return left_text, right_text, top_text, bottom_text

def calculate_iou(box1, box2):
    """
    计算两个框的交并比（IoU）。

    参数:
    box1 (tuple): 第一个框的坐标(xmin, ymin, xmax, ymax)。
    box2 (tuple): 第二个框的坐标(xmin, ymin, xmax, ymax)。

    返回:
    float: 两个框的交并比（IoU）。
    """
    xmin1, ymin1, xmax1, ymax1 = box1
    xmin2, ymin2, xmax2, ymax2 = box2

    # 计算交集坐标
    inter_xmin = max(xmin1, xmin2)
    inter_ymin = max(ymin1, ymin2)
    inter_xmax = min(xmax1, xmax2)
    inter_ymax = min(ymax1, ymax2)

    # 计算交集面积
    inter_area = max(0, inter_xmax - inter_xmin + 1) * max(0, inter_ymax - inter_ymin + 1)

    # 计算每个框的面积
    box1_area = (xmax1 - xmin1 + 1) * (ymax1 - ymin1 + 1)
    box2_area = (xmax2 - xmin2 + 1) * (ymax2 - ymax2 + 1)

    # 计算并集面积
    union_area = box1_area + box2_area - inter_area

    # 计算IoU
    iou = inter_area / union_area

    return iou

def check_overlap(box1, box2, iou_threshold=0.1):
    """
    检查两个框是否有重叠，并且重叠的IoU在一定范围内是可以接受的。

    参数:
    box1 (tuple): 第一个框的坐标(xmin, ymin, xmax, ymax)。
    box2 (tuple): 第二个框的坐标(xmin, ymin, xmax, ymax)。
    iou_threshold (float): 可接受的IoU阈值。

    返回:
    bool: 如果两个框的IoU在可接受范围内则返回True，否则返回False。
    """
    iou = calculate_iou(box1, box2)
    return iou >= iou_threshold

def check_no_overlap_with_others(expanded_box, formula_boxes, current_box):
    """
    检查扩展后的框是否与其他公式框重叠。

    参数:
    expanded_box (tuple): 扩展后的框的坐标(xmin, ymin, xmax, ymax)。
    formula_boxes (list): 所有公式框的列表。
    current_box (tuple): 当前正在处理的公式框。

    返回:
    bool: 如果扩展后的框与其他公式框没有重叠则返回True，否则返回False。
    """
    for other_box in formula_boxes:
        if other_box != current_box and check_overlap(expanded_box, other_box):
            return False
    return True

def expand_formula_boxes(formula_boxes, text_boxes):
    """
    扩展公式框，使其与最近的文本框融合。

    参数:
    formula_boxes (list): 公式框列表，格式为[(xmin, ymin, xmax, ymax), ...]。
    text_boxes (list): 文本框列表，格式为[(xmin, ymin, xmax, ymax), ...]。

    返回:
    list: 扩展后的公式框列表。
    """
    expanded_formula_boxes = []

    for formula_box in formula_boxes:
        xmin_f, ymin_f, xmax_f, ymax_f = formula_box

        left_text, right_text, top_text, bottom_text = find_closest_text_boxes(formula_box, text_boxes)

        # 扩展左边
        if left_text:
            xmin_t, ymin_t, xmax_t, ymax_t = left_text
            new_xmin_f = xmin_t
            expanded_box = (new_xmin_f, min(ymin_f, ymin_t), xmax_f, max(ymax_f, ymax_t))
            if check_no_overlap_with_others(expanded_box, formula_boxes, formula_box):
                xmin_f = new_xmin_f
                ymin_f = min(ymin_f, ymin_t)
                ymax_f = max(ymax_f, ymax_t)

        # 扩展右边
        if right_text:
            xmin_t, ymin_t, xmax_t, ymax_t = right_text
            new_xmax_f = xmax_t
            expanded_box = (xmin_f, min(ymin_f, ymin_t), new_xmax_f, max(ymax_f, ymax_t))
            if check_no_overlap_with_others(expanded_box, formula_boxes, formula_box):
                xmax_f = new_xmax_f
                ymin_f = min(ymin_f, ymin_t)
                ymax_f = max(ymax_f, ymax_t)

        # 扩展上边
        if top_text:
            xmin_t, ymin_t, xmax_t, ymax_t = top_text
            new_ymin_f = ymin_t
            expanded_box = (min(xmin_f, xmin_t), new_ymin_f, max(xmax_f, xmax_t), ymax_f)
            if check_no_overlap_with_others(expanded_box, formula_boxes, formula_box):
                ymin_f = new_ymin_f
                xmin_f = min(xmin_f, xmin_t)
                xmax_f = max(xmax_f, xmax_t)

        # 扩展下边
        if bottom_text:
            xmin_t, ymin_t, xmax_t, ymax_t = bottom_text
            new_ymax_f = ymax_t
            expanded_box = (min(xmin_f, xmin_t), ymin_f, max(xmax_f, xmax_t), new_ymax_f)
            if check_no_overlap_with_others(expanded_box, formula_boxes, formula_box):
                ymax_f = new_ymax_f
                xmin_f = min(xmin_f, xmin_t)
                xmax_f = max(xmax_f, xmax_t)

        expanded_formula_boxes.append((xmin_f, ymin_f, xmax_f, ymax_f))

    return expanded_formula_boxes

def draw_boxes_on_image(image_path, boxes, output_path):
    """
    在图像上绘制框并保存结果图像。

    参数:
    image_path (str): 输入图像的路径。
    boxes (list): 需要绘制的框的列表，格式为[(xmin, ymin, xmax, ymax), ...]。
    output_path (str): 保存结果图像的路径。
    """
    image = cv2.imread(image_path)

    def random_color():
        return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

    for box in boxes:
        xmin, ymin, xmax, ymax = box
        color = random_color()
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)

    cv2.imwrite(output_path, image)
    print(f"扩展后的图像已保存到 {output_path}")

def main(xml_file_path, image_file_path, output_image_path):
    """
    主函数，执行XML解析、公式框扩展和绘制框的操作。

    参数:
    xml_file_path (str): XML文件的路径。
    image_file_path (str): 输入图像的路径。
    output_image_path (str): 保存结果图像的路径。
    """
    formula_boxes, text_boxes = parse_xml(xml_file_path)
    expanded_formula_boxes = expand_formula_boxes(formula_boxes, text_boxes)
    
    for box in expanded_formula_boxes:
        print(box)
    
    draw_boxes_on_image(image_file_path, expanded_formula_boxes, output_image_path)

if __name__ == "__main__":
    main('1.xml', '1.png', './expanded_image.png')
