import cv2
import os

class ImageProcessor:
    """Crop regions of interest (ROIs) from an image, upscale them and save them as PNGs."""

    def enlarge_image(self, image, scale_factor=2):
        """Resize *image* by *scale_factor* using Lanczos interpolation.

        Args:
            image: Image array accepted by ``cv2.resize``; ``image.shape[:2]`` is
                read as (height, width).
            scale_factor: Multiplier applied to both width and height.

        Returns:
            The resized image. Each output dimension is ``int(dim * scale_factor)``,
            but never less than 1.
        """
        height, width = image.shape[:2]
        # Compute the *target* size. Clamp to 1 because cv2.resize rejects a
        # zero-sized dsize when no fx/fy is given (e.g. tiny input, factor < 1).
        new_width = max(1, int(width * scale_factor))
        new_height = max(1, int(height * scale_factor))
        # cv2.resize expects dsize as (width, height).
        return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)

    def process_image(self, image, rois, output_folder, *, scale_factor=2,
                      area_ratio=0.75, min_side=600):
        """Crop, enlarge and save every ROI of *image* as a PNG in *output_folder*.

        A ROI is split in half along its longer side when both of these hold:
        its area is at least ``area_ratio`` of the image area, and both its sides
        exceed ``min_side`` pixels. Each half is then saved as
        ``formula_{i+1}_{j+1}.png``. Every other ROI is saved as
        ``formula_{i+1}.png``.

        Args:
            image: Source image array; ``image.shape[:2]`` is (height, width).
            rois: Iterable of ``(x1, y1, x2, y2)`` pixel boxes (x2/y2 exclusive).
                Coordinates are truncated to int and clamped to the image bounds.
                Boxes that are empty after clamping are skipped.
            output_folder: Directory to write into. It is not created here.
            scale_factor: Upscaling factor passed to :meth:`enlarge_image`.
            area_ratio: Minimum fraction of the image area for a ROI to be split.
            min_side: A ROI is only split if both of its sides exceed this value.

        Returns:
            List of written file paths in ROI order (a split ROI contributes two).

        Raises:
            OSError: If ``cv2.imwrite`` reports that a file could not be written.
        """
        img_h, img_w = image.shape[:2]
        image_area = img_w * img_h
        extracted_image_paths = []

        for i, (x1, y1, x2, y2) in enumerate(rois):
            # Clamp to the image. Negative indices would wrap around in NumPy
            # slicing. Out-of-range ends would put the split midpoint outside
            # the real crop and produce an empty half.
            x1, x2 = self._clamp(x1, img_w), self._clamp(x2, img_w)
            y1, y2 = self._clamp(y1, img_h), self._clamp(y2, img_h)
            roi_w, roi_h = x2 - x1, y2 - y1
            if roi_w <= 0 or roi_h <= 0:
                # Nothing to crop; cv2.resize would raise on an empty array.
                continue

            is_large = (
                roi_w * roi_h >= area_ratio * image_area
                and roi_w > min_side
                and roi_h > min_side
            )
            if is_large:
                # Too large to handle as one piece: save its two halves instead.
                halves = self._split_in_half(x1, y1, x2, y2)
                for j, (sx1, sy1, sx2, sy2) in enumerate(halves):
                    path = os.path.join(output_folder, f"formula_{i+1}_{j+1}.png")
                    self._save(path, self.enlarge_image(image[sy1:sy2, sx1:sx2], scale_factor))
                    extracted_image_paths.append(path)
            else:
                path = os.path.join(output_folder, f"formula_{i+1}.png")
                self._save(path, self.enlarge_image(image[y1:y2, x1:x2], scale_factor))
                extracted_image_paths.append(path)

        return extracted_image_paths

    @staticmethod
    def _clamp(value, upper):
        """Truncate *value* to int and clamp it to ``[0, upper]``."""
        return min(max(int(value), 0), upper)

    @staticmethod
    def _split_in_half(x1, y1, x2, y2):
        """Return two boxes that halve ``(x1, y1, x2, y2)`` across its longer side.

        A box wider than it is tall is cut into left/right halves. Any other box,
        including a square one, is cut into top/bottom halves.
        """
        if x2 - x1 > y2 - y1:
            mid_x = x1 + (x2 - x1) // 2
            return [(x1, y1, mid_x, y2), (mid_x, y1, x2, y2)]
        mid_y = y1 + (y2 - y1) // 2
        return [(x1, y1, x2, mid_y), (x1, mid_y, x2, y2)]

    @staticmethod
    def _save(path, image):
        """Write *image* to *path*, raising OSError instead of failing silently."""
        # cv2.imwrite reports failure (e.g. the target folder does not exist)
        # by returning False rather than raising.
        if not cv2.imwrite(path, image):
            raise OSError(f"Failed to write image to {path!r}")

# Example run with a sample image and hard-coded ROIs; replace with real initialization code.
if __name__ == "__main__":
    import sys

    processor = ImageProcessor()
    # Allow the image path to be passed on the command line; fall back to the sample path.
    image_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/data/wangtengbo/formula_correct/test_data/SQW@MFFBGE28F2%S%TV]M[0.png"
    )
    image = cv2.imread(image_path)  # read the image
    # cv2.imread returns None (it does not raise) when the file is missing or unreadable.
    if image is None:
        sys.exit(f"Could not read image: {image_path}")
    rois = [(50, 50, 700, 1300), (800, 100, 1500, 1600)]  # example (x1, y1, x2, y2) boxes
    output_folder = "output"
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(output_folder, exist_ok=True)

    extracted_image_paths = processor.process_image(image, rois, output_folder)
    print(extracted_image_paths)
