Commit 905b999b by unknown

init

parent 245877af
图标题匹配
图正文匹配
\ No newline at end of file
图文匹配-标题
图文匹配-正文
config.py
CHECK_IMAGE_CAPTION=True # 一、图像和标题匹配 开关
CHECK_IMAGE_CAPTION_VLM=True # 图像和标题匹配 提示词开关
CHECK_IMAGE_CONTEXT=False # 二、图像和正文匹配 开关
CHECK_IMAGE_CONTEXT_VLM=False # 图像和正文匹配 提示词开关
服务启动
```
# cd /nfs/liuxin/work/Image_TextTitle_Matching
# conda activate text_check
# nohup python -u main.py > main_image_title.log 2>&1 &
# tail -f main_image_title.log
# ss -ntlp | grep 29500
```
import os
from enum import Enum
ENV = os.getenv('AI_CORRECTOR_ENV', 'PRO')
DATA_PDFPARSE_DIR = os.environ.get('DATA_PDFPARSE_DIR_PRO','/nfs/liuxin/work/Image_TextTitle_Matching/Image_Text_Matching_Server_Pro') # 读取环境变量
#DATA_PDFPARSE_DIR='/home/share/shucshqyfzyxgsi/home/dcai/dcg-ai-logs/formula'
USE_TYPE = os.environ.get('USE_TYPE')
DATA_LOGS_DATA = os.path.join(DATA_PDFPARSE_DIR,'logs')
if not os.path.exists(DATA_LOGS_DATA):
os.makedirs(DATA_LOGS_DATA)
CHECK_IMAGE_CAPTION=True # 一、图像和标题匹配 开关
CHECK_IMAGE_CAPTION_VLM=True # 图像和标题匹配 提示词开关
CHECK_IMAGE_CONTEXT=False # 二、图像和正文匹配 开关
CHECK_IMAGE_CONTEXT_VLM=False # 图像和正文匹配 提示词开关
OBS_TOKEN="dcg-4c1e3a7f4fcd415e8c93151ff539d20a"
QWEN_URL='https://dashscope.aliyuncs.com/compatible-mode/v1'
QWEN_API_KEY='sk-b68836eb4e934d07b7b6c1714cd506c6'
QWEN_MODEL='qwen-vl-max-latest'
#最新的版面测试url
#LAYOUT_CHECK_URL="http://192.168.1.235:30020/v1/dcg_layout"
#生产环境
# LAYOUT_CHECK_URL="http://192.168.161.53:10100/v1/dcg_layout"
#LAYOUT_CHECK_URL="http://192.168.161.53:10100/v1/dcg_layout"
#最新的版面测试url
LAYOUT_CHECK_URL="http://117.147.213.226:30020/v1/dcg_layout"
LAYOUT_CHECK_URL_19="http://117.147.213.226:30019/v1/dcg_layout"
#最新的公式检测服务 TexTeller-生产环境
# FORMULA_DETECTION_URL="http://192.168.161.53:10280/fdet"
#测试环境
#FORMULA_DETECTION_URL="http://127.0.0.1:32039/fdet"
#4090的255上的测试服务
FORMULA_DETECTION_URL="http://127.0.0.1:10039/fdet"
# 后端代理
OPENAI_API_KEY = "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"
FOREIGN_OPENAI_URL = 'http://47.251.70.106:8080/proxy/v1.0/chatgpt/chat'
# # 直接本地代理
# Official_API_KEY = "sk-proj-zo8fJFQDkPhgRnSa7Ck8T3BlbkFJWNKyWmzdABduJWpKSKed"
# # Official_OPENAI_URL = 'https://api.openai.com/v1/chat/completions'
# Official_OPENAI_URL = 'http://47.251.57.88:9399/v1/chat/completions'
import requests
url = "https://open.raysgo.com/aimodel/v1.0/chatGptAccount/getBySceneCode?sceneCode=proofreading"
response = requests.get(url)
Official_API_KEY=response.json()['data']['apiKey']
# 直接本地代理
#Official_API_KEY = "sk-proj-Mw60o2jXioyJwyer9FYrT3BlbkF]n7UXXOF3zb70HcyvofM7"
# Official_OPENAI_URL = 'https://api.openai.com/v1/chat/completions'
Official_OPENAI_URL = 'http://47.251.57.88:9399/v1/chat/completions'
#Textln的秘钥
APP_ID='2bae9d892b488e1d29f4ca0b650ad7a5'
SECRET_CODE='46d25182d74cb02b9fb8a06734fb73a4'
#当前的Prompt:
#提取图像文本和公式内容-----------提取所有的内容!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# Context_Extract_Prompts = """
# SYSTEM:
# 使用markdown语法,你是一个提取助手,不是一个补充模型。专注于从图片中提取文字转换为markdown格式输出。你必须做到:
# ###1. 提取图中所有的文字内容,根据图中箭头的导向顺序提取内容,,必须保证信息的所属关系顺序,不做任何的修改操作。
# 2. 不要解释和输出无关的文字,直接输出图片中的内容。例如,严禁输出 “以下是我根据图片内容生成的markdown文本:”这样的例子,而是应该直接输出markdown。
# 3. 内容不要包含在```markdown ```中、段落公式使用 $$ $$ 的形式、行内公式使用 $ $ 的形式。解析的内容必须是结构化的,对于表格内容需要保证markdown的。
# ####再次强调,不要解释和输出无关的文字,不要修改图中提取到的任何内容、符号、正负号等等,否则你将受到惩罚。
# USER:
# ""图像""
# Goal:
# ###1.专注于从图片提取所有内容,必须根据图中箭头的导向顺序提取内容,必须保证信息的所属关系顺序,按照markdown的布局展示,否则你将受到惩罚。
# 2.严禁修改提取到的任何内容,不要丢失信息,否则你将受到惩罚。提取到的文本内容都是可以在图中找到的。
# 3.解析的内容必须是结构化的,对于表格内容需要保证markdown的。
# 4.如果是正文文字部分,例如段落序号(一、二....)、【答案】、【例题】、【例】、【例N】这样的内容应该当做一个段落使用标题符号#来提取,如果没有明显的段落标识,则将正文部分组合起来当成一个段落使用##来展示。
# 5. 对于转写段落公式和行内公式的时候:保持原有的字母大小写。不进行任何字符替换或修改。精确复制图像中的每一个字母和符号。
# ####再次强调,不要解释和输出无关的文字,不要修改图中提取到的任何内容、符号、正负号等等,否则你将受到惩罚。
# """
# Context_Extract_Prompts = """
# SYSTEM:
# 使用markdown语法,专注于从图片中提取文字转换为markdown格式输出。你必须做到:
# ###1. 提取图中所有的文字内容,根据图中箭头的导向顺序提取内容,,必须保证信息的所属关系顺序,不做任何的修改操作。
# 2. 不要解释和输出无关的文字,直接输出图片中的内容。例如,严禁输出 “以下是我根据图片内容生成的markdown文本:”这样的例子,而是应该直接输出markdown。
# 3. 内容不要包含在```markdown ```中、段落公式使用 $$ $$ 的形式、行内公式使用 $ $ 的形式。
# ####再次强调,不要解释和输出无关的文字,不要修改图中提取到的任何内容、符号、正负号等等,否则你将受到惩罚。
# USER:
# ""图像""
# Goal:
# ###1.专注于从图片提取所有内容,必须根据图中箭头的导向顺序提取内容,必须保证信息的所属关系顺序,按照markdown的布局展示,否则你将受到惩罚。
# 2.严禁修改提取到的任何内容,不要丢失信息,否则你将受到惩罚。提取到的文本内容都是可以在图中找到的。
# ####再次强调,不要解释和输出无关的文字,不要修改图中提取到的任何内容、符号、正负号等等,否则你将受到惩罚。
# """
#提取图像文本和公式内容-----------提取所有的内容!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Context_Extract_Prompts = """
SYSTEM:
使用markdown语法,你是一个提取内容和拷贝助手,不是一个修复和修改模型。专注于从图片中提取文字转换为markdown格式输出。你必须做到:
###1. 提取图中所有的文字内容,根据图中箭头的导向顺序提取内容,,必须保证信息的所属关系顺序,不做任何的修改操作。
2. 不要解释和输出无关的文字,直接输出图片中的内容。例如,严禁输出 “以下是我根据图片内容生成的markdown文本:”这样的例子,而是应该直接输出markdown。
3. 内容不要包含在```markdown ```中、段落公式使用 $$ $$ 的形式、行内公式使用 $ $ 的形式、下划线不要省略。解析的内容必须是结构化的,对于表格内容需要保证markdown的。
4.提取的所有内容都放到三个引号中,并前边放置r,避免字符转义例如:r"""'text'"""。
5.如果是正文文字部分,例如段落序号(一、二....)、答案、例题、例、例N这样的内容应该当做一个段落使用标题符号#来提取,内容部分放在以#为标题的下方,不要放到同一行上,保持markdown的标准,如果没有明显的段落标识,则将正文部分组合起来当成一个段落使用##来展示。
####再次强调,不要解释和输出无关的文字,不要修改图中提取到的任何内容、符号、正负号等等,否则你将受到惩罚。
USER:
""图像""
Goal:
###1.专注于从图片提取所有内容,必须根据图中箭头的导向顺序提取内容,必须保证信息的所属关系顺序,按照markdown的布局展示,否则你将受到惩罚。
2.严禁修改提取到的任何内容,不要丢失信息,否则你将受到惩罚。提取到的文本内容都是可以在图中找到的。
3.解析的内容必须是结构化的,对于表格内容需要保证markdown的。
4.如果是正文文字部分,例如段落序号(一、二....)、答案、例题、例、例N这样的内容应该当做一个段落使用标题符号#来提取,内容部分放在以#为标题的下方,一定不要放到同一行上,保持markdown的标准,如果没有明显的段落标识,则将正文部分组合起来当成一个段落使用##来展示。
5. 对于转写段落公式和行内公式的时候:保持原有的字母大小写。不进行任何字符替换或修改。精确复制图像中的每一个字母和符号。
6.提取的所有内容都放到三个引号中,并前边放置r,避免字符转义例如:r"""'text'"""。
####再次强调,不要解释和输出无关的文字,不要修改图中提取到的任何内容、符号、正负号等等,否则你将受到惩罚。
"""
# """
#内容过滤v2-------------------------------------------线上版本!!!!!!!!!!!!!!!!!!!!
# Context_Filter_Prompts = """
# SYSTEM:
# 你将获得一份文本内容。你的任务是专注于提取文本中的公式以及公式的名称和定义内容,过滤掉各种题目内容、表达式内容和变量定义等内容,*不要修改提取的任何文本内容*。如果内容中包含公式则将提取到的公式以及名称与定义,则将公式以及公式的名称和定义内容按照Markdown的格式输出,如果没有公式则输出空字符串'',不要自己编造和修改文本,否则你受到惩罚。
# 公式是用符号和变量表示的数学关系或规律,是解决特定类型问题的工具。例如,面积公式、体积公式等。仅仅是中文描述信息一定没有公式。你专注于分辨公式与表达式的区别,因为你能很简单的过滤掉变量。表达式:x+y。变量:x,y,z=5等。公式是经过严格证明的数学命题,是在特定公理体系下通过逻辑推理得出的结论。例如,勾股定理就是一个经典的数学定理。
# ###必须满足:提取公式的同时需要提取前三行的上下文信息,如果公式是包含在$符号之间需要提取前三行描述公式的文本内容,包含在$$之间的需要提取临近三行的文本内容。表格中的公式需要提取表格内容和结构内容。一定不要包含在```markdown ```中。
# USER:
# ""<input>""
# Goal:
# 1.你的任务是专注于提取文本中的公式以及公式的名称和定义内容,过滤掉各种题目内容、表达式内容和变量定义等内容,有结果值的都是表达式需要过滤,*不要修改提取的任何文本内容*。如果内容中包含公式则将提取到的公式以及名称与定义,则将公式以及公式的名称和定义内容按照Markdown的格式输出,如果没有公式则输出空字符串'',不要自己编造和修改文本,否则你受到惩罚。
# 2.*不要修改提取的任何内容*,否则你将受到惩罚!过滤后的文本内容需要和输入的数据式一致性的,可以来复查的。一定不要包含在```markdown ```中。
# 3.如果输入没有任何公式,那么它一定是没有公式,输出空字符串即可,严禁按照输入信息补写、续写、改写等,否则你将受到惩罚。
# 4.将提取到的公式、公式对应的名称以及定义等文本信息按照markdown的形式输出,仅将抽取的结果输出。
# # 5.必须过滤掉类似于这样的运算式和式子:r"\overrightarrow{BA}=\overrightarrow{OC}=a-b"、r"A = A_1A_2 \cup A_1 \overline{A_2} \cup \overline{A_1}A_2"、r"A_1 \\overline{A_2} = A_1 \\overline{A_2}"、'$$\\cdots x-2 \\leq 0, x+3 \\geq 0, x-5 < 0.$$'、a-2=0、b+1/2=0、r"\therefore x = \pm \sqrt{1800}。"、r'\Delta v_{\text{飞}} = 300 - 0 = 300(km/s) = 83(m/s)'、 r"$|x-2|+\sqrt{(x+3)^2}+\sqrt{x^2-10x+25}$"、r"\overrightarrow{CA} = \overrightarrow{OA} - \overrightarrow{OC}=__"、r"$h = 5t^2$",x = 30 \sqrt{2}、'\\overrightarrow{AC}=\\overrightarrow{AB}+\\overrightarrow{AD}=a+b', '\\overrightarrow{DB}=\\overrightarrow{AB}-\\overrightarrow{AD}=a-b'。
# # 5.2 具有推导过程的都需要过滤,必须过滤掉的运算式案例:1^2 + 2^2 = \frac{2 \times 3 \times 5}{6}、y=\frac{1}{2}、r"\overrightarrow{CA} = \overrightarrow{OA} - \overrightarrow{OC}"、 r"\cos (A - B) \cos (B - C) \cos (C - A) = 1"、r"\\overrightarrow{AC}=\\overrightarrow{AB}+\\overrightarrow{AD}=a+b"、't+\\frac{1}{2}\\lambda=\\frac{3}{2}\\lambda+1=0'、'\\frac{2}{30} + \\frac{8}{30} + \\frac{8}{30} + \\frac{18}{30} = \\frac{3}{5}'、
# # 输出的内容中的公式一定符合Latex的标准格式,去除无用的符号和文本!否则你将受到惩罚。
# """
# # -------------------------------------新版本的过滤?-----------------------------------------
# Context_Filter_Prompts = """你负责从用户输入的markdown中提取包含的公式以及公式的描述提的文本取出来,过滤掉无用的文本、表达式和变量定义。一定不要修改提取的任何文本内容,确保输出的文本中的公式格式符合标准的Latex格式,可以删减噪音字符。
# 提取公式及其描述的过程中需要满足以下要求:
# 1.公式是用符号和变量表示的数学关系或规律,是解决特定类型问题的工具。例如,面积公式、体积公式等。仅仅是中文描述信息一定没有公式。你专注于分辨公式与表达式的区别,因为你能很简单的过滤掉变量。表达式:x+y。变量:x,y,z=5等。公式是经过严格证明的数学命题,是在特定公理体系下通过逻辑推理得出的结论。例如,勾股定理就是一个经典的数学定理。
# 2.必须满足:提取公式的同时需要提取前三行的上下文信息,如果公式是包含在$符号之间需要提取前三行描述公式的文本内容,包含在$$之间的需要提取临近三行的文本内容。表格中的公式需要提取表格内容和结构内容。
# 3.不要修改提取的任何内容,否则你将受到惩罚!一定不要包含在```markdown ```中。
# 4.不要修改任何内容。
# # 用户输入的markdown内容:
# <input>
# # 最后输出包含公式以及描述的内容。
# """
# -------------------------------------新版本的过滤?-----------------------------------------
Context_Filter_Prompts = """你负责从用户输入的markdown中提取包含的公式以及公式的描述提的文本取出来,过滤掉无用的文本、表达式和变量定义。一定不要修改提取的任何文本内容,确保输出的文本中的公式格式符合标准的Latex格式,可以删减噪音字符。
# 需要保留的内容:
1.定律:电功、电功率、一元二次方程、焦耳定律等等。
2.此类式子也属于公式:$( \sqrt { a } ) ^ { 2 } = a ( a \geq 0 )$ 、$\sqrt { a ^ { 2 } } = a ( a \geq 0 )$
# 需要过滤掉的内容:
1.带有数值运算的计算式,即等号右边是明显是数值。例如:P+Q=1.2, 1+2=3, 1*4=4、AC+BC=15、$R _ { 灯 } = \frac { E } { I _ { 1 } } - r = 1 . 2 \Omega$等。
2.数值定义赋值式,例如A=100,赋值。例如:$P _ { 灯 1 } = 1 2 0 W$、A=15、$I _ { 2 } = 5 8$A等。
3.无意义的比较式。例如:A>C、W > Q、A+B<C+D等
4.过滤掉选项的内容。例如:B.当$n _ { 1 } > n _ { 2 }$时,电流方向从A→B,电流强度$I=\frac { ( n _ { 1 } - n _ { 2 } ) e } { t }$、C.当$n _ { 1 } < n _ { 2 }$时,电流方向从B→A,电流强度$I = \frac { ( n _ { 2 } - n _ { 1 } ) e } { t }$等等。
# 用户输入的markdown内容:
<input>
# 最后输出包含公式以及描述的内容,输出的内容必须在用户输入的内容存在!!
# 提取内容,不要修改原始文本中的任何内容,否则你将受到惩罚!
# 最后文本中的公式应该用Latex的格式严格表示。
"""
#公式审校----添加公式名称的纠错能力!
#公式审校----添加公式名称的纠错能力!
# Context_Checker_Prompts = """
# SYSTEM:
# 你将获得一份文本内容,包含了公式和文本内容,文本内容可能包含公式的名称以及定义。你的任务有三个:(1)找到出错的所有定理公式和公式的名称与描述(2)出错的原因(3)根据公式的名称、上下文内容和定义来将找到的出错公式纠正。如果存在错误,则最后将所有(出错的定理公式和公式名称)、出错原因和对应的(纠正后的定理公式和公式名称) 按照{"error_formula":[],"error_reason":[],"corrected_formula":[]}的格式输出。如果没有错误则返回空的{"error_formula":[],"error_reason":[],"corrected_formula":[]},没有错误不要自己创造错误,也不要随机捏造错误原因,否则将受到惩罚。
# 根据公式的名称以及上下文来发现错误并纠正,公式与常见的用法和定义匹配,即可认为公式是正确的。没有上下文的情况下,你只需要输出置信的结果,缺少信息,仅关注公式本身的错误,例如:U1=U2=U3等是正确的,不要考虑使用条件的问题,是正确的,没有缺少等号这样的错误,可能是没有识别到。如果包含文本,那么考虑文本与公式的对应关系。
# USER:
# ""<input>""
# Goal:
# 1. 你的任务有三个:(1)找到出错的所有定理公式和公式的名称与描述(2)出错的原因(3)根据公式的名称、上下文内容和定义来将找到的出错公式纠正。最后将所有(出错的定理公式和公式名称)、出错原因和对应的(纠正后的定理公式和公式名称) 按照{"error_formula":[],"error_reason":[],"corrected_formula":[]}的格式输出。如果没有错误则返回空的{"error_formula":[],"error_reason":[],"corrected_formula":[]},不要自己创造错误,否则将受到乘法。
# 2. 定理公式的科目是:数学和物理,物理中有一些公式常见的用法:电阻使用大写R、在电荷的定义表示中使用大写Q、自感电动势中使用大写的I、磁感应强度使用B、W电工和Q焦耳定律式是需要时间变量t、变量名使用错误、极限符号不全(lim)、单位缺少平方、多或少变量、2次方和3次方混用、大小写错误、加号和减号混用、在存在上下文前提下,使用的符号不是上下文中定义的符号、周期与频率符号的正确使用、多余的变量的添加、大小号、小于号、等于号混用。
# 2.2 数学中有一些常见问题:公式的标准形式写错、乘法与加法符号混淆、2次方符号错写成3次方、公式与描述不符,例如\sqrt{}是一次根号、根据描述,公式的大于号与小于号写反、正负号写错、缺少平方符号、公式变量缺少范围、变量的范围错误。公式范围在小学、初中以及高中课本中能找到明确定义的内容,一步步思考一个公式的合理性,合理则认为正确。如果没有公式对应的文本信息,只关注公式本身的错误,不要考虑公式是否适用于当前情况。如果包含文本,那么考虑公式的名称与公式的对应关系。
# 3. 再次强调,输出格式严格按照{"error_formula":[],"error_reason":[],"corrected_formula":[]},否则你将受到惩罚!内容不要包含在```markdown ```中。其中error_formula和corrected_formula的内容要标准的Latex格式,可能包含公式的名称。
# 4. 不要使用无法解析的Latex符号,例如\\rac、\\x0crac,除法使用\\frac。不要添加过多的\\。error_reason要简洁。输出的结果中error_formula和corrected_formula中的每一个Latex公式使用单斜杠,避免转义造成错误。内容不要包含在```markdown ```中。
# # 输出结果的格式中的所有内容的字符串前都加一个r,例如:r"error_formula"、r"error_reason"、r"corrected_formula",务必添加以避免转义错误!!否则你将受到惩罚!
# # 输出结果{"error_formula":[],"error_reason":[],"corrected_formula":[]}每一个列表中的数据个数肯定是相等的。
# """
# Context_Checker_Prompts = """
# 系统提示:
# 你将获得一份文本内容,包含了公式和文本内容,文本内容可能包含公式的名称以及定义。你的任务有三个:(1)找到出错的所有定理公式和公式的名称与描述(2)出错的原因(3)根据公式的名称、上下文内容和定义来将找到的出错公式纠正。如果存在错误,则最后将所有(出错的定理公式和公式名称)、出错原因和对应的(纠正后的定理公式和公式名称)
# 根据公式的名称以及上下文来发现错误并纠正,公式与常见的用法和定义匹配,即可认为公式是正确的。没有上下文的情况下,你只需要输出置信的结果,缺少信息,仅关注公式本身的错误,例如:U1=U2=U3等是正确的,不要考虑使用条件的问题,是正确的,没有缺少等号这样的错误,可能是没有识别到。如果包含文本,那么考虑文本与公式的对应关系。
# 用户输入文本:
# ""<input>""
# 目标:
# 1. 你的任务有三个:(1)找到出错的所有定理公式和公式的名称与描述(2)出错的原因(3)根据公式的名称、上下文内容和定义来将找到的出错公式纠正。
# 2. 定理公式的科目是:数学和物理,物理中有一些公式常见的用法:电阻使用大写R、在电荷的定义表示中使用大写Q、自感电动势中使用大写的I、磁感应强度使用B、W电工和Q焦耳定律式是需要时间变量t、变量名使用错误、极限符号不全(lim)、单位缺少平方、多或少变量、2次方和3次方混用、大小写错误、加号和减号混用、在存在上下文前提下,使用的符号不是上下文中定义的符号、周期与频率符号的正确使用、多余的变量的添加、大小号、小于号、等于号混用。
# 2.2 数学中有一些常见问题:公式的标准形式写错、乘法与加法符号混淆、2次方符号错写成3次方、公式与描述不符,例如\sqrt{}是一次根号、根据描述,公式的大于号与小于号写反、正负号写错、缺少平方符号、公式变量缺少范围、变量的范围错误。公式范围在小学、初中以及高中课本中能找到明确定义的内容,一步步思考一个公式的合理性,合理则认为正确。如果没有公式对应的文本信息,只关注公式本身的错误,不要考虑公式是否适用于当前情况。如果包含文本,那么考虑公式的名称与公式的对应关系。
# ##输出结果要求
# 1. 再次强调,输出格式严格按照{r"error_formula":[],r"error_reason":[],r"corrected_formula":[]},否则你将受到惩罚!内容不要包含在```markdown ```中。其中error_formula和corrected_formula的内容要标准的Latex格式,可能包含公式的名称。
# 2. 不要使用无法解析的Latex符号,例如\\rac、\\x0crac,除法使用\\frac。不要添加过多的\\。
# 3.error_reason内的内容要简洁。输出的结果中error_formula和corrected_formula中的每一个Latex公式使用单斜杠,避免转义造成错误。内容不要包含在```markdown ```中。
# ##输出格式要求
# # 输出结果的格式中的所有内容的字符串前都加一个r,例如:r"error_formula"、r"error_reason"、r"corrected_formula",务必添加以避免转义错误!!否则你将受到惩罚!
# # 输出结果{"error_formula":[],"error_reason":[],"corrected_formula":[]}每一个列表中的数据个数肯定是相等的。error_reason内的内容要简洁
# """
# -------------------------------------------###生产版本!!!!!!!!!!!!--------------最新结果版本!!!!!!!!
# Context_Checker_Prompts = """
# ## 角色和任务:
# 你是一名数学和物理科目的高级学者,负责做这两个学科的公式的校对。公式指的是所有的定理公式,而不是表达式或者变量或者计算错误等。
# 你的任务有三个:(1)从用户输入的文本找到出错的所有定理公式和公式的名称与描述(2)出错的原因(3)根据公式的名称、上下文内容和定义来将找到的出错公式纠正。
# ## 内容提示:
# 1. 物理知识。物理中有一些公式常见的用法:电阻使用大写R、在电荷的定义表示中使用大写Q、自感电动势中使用大写的I、磁感应强度使用B、W电工和Q焦耳定律式是需要时间变量t、变量名使用错误、极限符号不全(lim)、单位缺少平方、多或少变量、2次方和3次方混用、大小写错误、加号和减号混用、在存在上下文前提下,使用的符号不是上下文中定义的符号、周期与频率符号的正确使用、多余的变量的添加、大小号、小于号、等于号混用。
# 2. 数学知识。数学中有一些常见问题:公式的标准形式写错、乘法与加法符号混淆、2次方符号错写成3次方、公式与描述不符,例如\sqrt{}是一次根号、根据描述,公式的大于号与小于号写反、正负号写错、缺少平方符号、公式变量缺少范围、变量的范围错误。公式范围在小学、初中以及高中课本中能找到明确定义的内容,一步步思考一个公式的合理性,合理则认为正确。如果没有公式对应的文本信息,只关注公式本身的错误,不要考虑公式是否适用于当前情况。如果包含文本,那么考虑公式的名称与公式的对应关系。
# ## 用户输入文本内容:
# ""<input>""
# ## 严格按照下面的输出格式:
# {r"error_formula":[r"",r"",],r"error_reason":[r"",r"",],r"corrected_formula":[r"",r"",]}
# ## 输出结果要求:
# 1.内容不要包含在```markdown ```中。
# 2.其中error_formula和corrected_formula对应的列表内的公式要使用标准的Latex格式。
# 3.所有内容的字符串前都加一个r,例如:r"error_formula"、r"error_reason"、r"corrected_formula",务必添加以避免转义错误!!否则你将受到惩罚!
# 4.error_formula中存储着每一个错误的公式表示,error_reason中存储着每一个错误原因,corrected_formula中存储着每一个纠正后的公式,包含的文本的数量也是相同的。
# """
#---------------------------新的
Context_Checker_Prompts = """ ## 角色和任务:
你是一名数学和物理科目的数学公式、定理审校专家,负责检查和纠正定理公式是否符合其定义。公式指的是所有的定理公式,而不是表达式或者变量或者计算错误等。
## 内容提示:
1. 物理知识。电阻使用大写R、在电荷的定义表示使用大写Q、自感电动势使用大写I、磁感应强度使用B、电工和焦耳定律式是需要乘变量t、变量名使用错误、极限符号不全(lim)、单位缺少平方、多或少变量、2次方和3次方混用、大小写错误、加号和减号混用、在存在上下文前提下,使用的符号不是上下文中定义的符号、周期与频率符号的正确使用、多余的变量的添加、大小号、小于号、等于号混用。
2. 数学知识。常见错误:公式的标准形式写错、乘法与加法符号混淆、2次方符号错写成3次方、公式与描述不符,例如\sqrt{}是一次根号、根据描述,公式的大于号与小于号写反、正负号写错、缺少平方符号、公式变量缺少范围、变量的范围错误。公式范围在小学、初中以及高中课本中能找到明确定义的内容,一步步思考一个公式的合理性,合理则认为正确。如果没有公式对应的文本信息,只关注公式本身的错误,不要考虑公式是否适用于当前情况。如果包含文本,那么考虑公式的名称与公式的对应关系。
## 用户输入文本内容:
""<input>""
## 严格按照下面的输出格式:
{r"error_formula":[r"",r"",],r"error_reason":[r"",r"",],r"corrected_formula":[r"",r"",]}
## 输出结果要求:
1.内容不要包含在```markdown ```中。
2.其中error_formula和corrected_formula对应的列表内的公式要使用标准的Latex格式。
3.所有内容的字符串前都加一个r,例如:r"error_formula"、r"error_reason"、r"corrected_formula",务必添加以避免转义错误!!否则你将受到惩罚!
4.error_formula中存储着每一个错误的公式表示,error_reason中存储着每一个错误原因,corrected_formula中存储着每一个纠正后的公式,包含的文本的数量也是相同的。
"""
# Context_Checker_Prompts = """你是一名数学和物理科目的高级专家,负责对公式进行检查并纠正。公式指的是所有的定理公式,而不是表达式或者变量或者计算错误等。
# 按照以下提示和要求对公式进行检查纠正:
# ## 内容提示:
# 1. 物理知识。物理中有一些公式常见的用法:电阻使用大写R、在电荷的定义表示中使用大写Q、自感电动势中使用大写的I、磁感应强度使用B、W电工和Q焦耳定律式是需要时间变量t、变量名使用错误、极限符号不全(lim)、单位缺少平方、多或少变量、2次方和3次方混用、大小写错误、加号和减号混用、在存在上下文前提下,使用的符号不是上下文中定义的符号、周期与频率符号的正确使用、多余的变量的添加、大小号、小于号、等于号混用。
# 2. 数学知识。数学中有一些常见问题:公式的标准形式写错、乘法与加法符号混淆、2次方符号错写成3次方、公式与描述不符,例如\sqrt{}是一次根号、根据描述,公式的大于号与小于号写反、正负号写错、缺少平方符号、公式变量缺少范围、变量的范围错误。公式范围在小学、初中以及高中课本中能找到明确定义的内容,一步步思考一个公式的合理性,合理则认为正确。如果没有公式对应的文本信息,只关注公式本身的错误,不要考虑公式是否适用于当前情况。如果包含文本,那么考虑公式的名称与公式的对应关系。
# ## 输出结果要求:
# 1.内容不要包含在```markdown ```中。
# 2.其中error_formula和corrected_formula对应的列表内的公式要使用标准的Latex格式,同时需要是完整的公式,使用的符号需要和原始公式对齐,包括大小写。不要在这两个部分添加文字说明。
# 3.所有内容的字符串前都加一个r,例如:r"error_formula"、r"error_reason"、r"corrected_formula",务必添加以避免转义错误!!否则你将受到惩罚!
# 4.错误的原因error_reason 要简洁。
# ## 严格按照下面的输出格式:
# {r"error_formula":[r"",r"",],r"error_reason":[r"",r"",],r"corrected_formula":[r"",r"",]}
# ## 用户输入文本内容:
# <input>
# """
###############################################################
###############################################################
# 特殊字符处理
pdf_character_map = {
'\u2002': ' ',
'\u3000': ' ',
'\u2003': ' ',
}
#-----------------生产环境同测试环境
MYSQL_DB_URL = "mysql+pymysql://lanhaibin:nF3uF&Ewc7mE@122.112.214.103:3306/dcg-ai-knowledge-base?charset=utf8"
# # 测试环境mysql集群节点
# # MYSQL_DB_URL = "mysql+pymysql://root:dcg$AI123@192.168.1.45:3306/dcg-ai-knowledge-base?charset=utf8"
# # 线上环境
# MYSQL_DB_URL = "mysql+pymysql://lanhaibin:nF3uF&Ewc7mE@122.112.214.103:3306/dcg-ai-knowledge-base?charset=utf8"
# # 测试环境的 API_KEY & ENDPOINT
# #OPENAI_API_KEY_URL = "https://adviser.raysgo.com/aimodel/v1.0/chatGptAccount/getBySceneCode?sceneCode=pdf-parse"
# # 正式环境获取API_KEY & ENDPOINT
# OPENAI_API_KEY_URL = "https://adviser.5rs.me/aimodel/v1.0/chatGptAccount/getBySceneCode?sceneCode=pdf-parse"
# # 线上OCR
# OCR_URL="http://10.233.5.68:8080/v1/dcg_ocr"
# #OCR_URL="https://dcg-ai-red-list.5rs.me/v1/dcg_ocr"
# # 线上版面检测
# LAYOUT_CHECK_URL="http://192.168.161.53:10100/v1/dcg_layout"
# #LAYOUT_CHECK_URL="http://192.168.1.235:30016/v1/dcg_layout"
# GPT_MAP = {
# "gpt3.5": "gpt-3.5-turbo-0125",
# "gpt4": "gpt-4-0125-preview",
# "gpt4o":"gpt-4o"
# }
# # 后端代理
# OPENAI_API_KEY = "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"
# FOREIGN_OPENAI_URL = 'http://47.251.70.106:8080/proxy/v1.0/chatgpt/chat'
# # 直接本地代理
# Official_API_KEY = "sk-proj-zo8fJFQDkPhgRnSa7Ck8T3BlbkFJWNKyWmzdABduJWpKSKed"
# # Official_OPENAI_URL = 'https://api.openai.com/v1/chat/completions'
# Official_OPENAI_URL = 'http://47.251.57.88:9399/v1/chat/completions'
# PROMPT="""
# {page_content}
# -----------------------
# 1. 请你根据以上内容,请你找出其中书的目录结构,要求即详细又精确,按照顺序输出格式为:['第一章章节名称', '第二章章节名称']
# 2. 如果找不出,输出: []
# """
# #当前的Prompt:
# #提取图像文本和公式内容-----------提取所有的内容!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# Context_Extract_Prompts = """
# SYSTEM:
# 使用markdown语法,你是一个提取内容和拷贝助手,不是一个修复和修改模型。专注于从图片中提取文字转换为markdown格式输出。你必须做到:
# ###1. 提取图中所有的文字内容,根据图中箭头的导向顺序提取内容,,必须保证信息的所属关系顺序,不做任何的修改操作。
# 2. 不要解释和输出无关的文字,直接输出图片中的内容。例如,严禁输出 “以下是我根据图片内容生成的markdown文本:”这样的例子,而是应该直接输出markdown。
# 3. 内容不要包含在```markdown ```中、段落公式使用 $$ $$ 的形式、行内公式使用 $ $ 的形式、下划线不要省略。解析的内容必须是结构化的,对于表格内容需要保证markdown的。
# 4.提取的所有内容都放到三个引号中,并前边放置r,避免字符转义例如:r"""'text'"""。
# 5.如果是正文文字部分,例如段落序号(一、二....)、答案、例题、例、例N这样的内容应该当做一个段落使用标题符号#来提取,内容部分放在以#为标题的下方,不要放到同一行上,保持markdown的标准,如果没有明显的段落标识,则将正文部分组合起来当成一个段落使用##来展示。
# ####再次强调,不要解释和输出无关的文字,不要修改图中提取到的任何内容、符号、正负号等等,否则你将受到惩罚。
# USER:
# ""图像""
# Goal:
# ###1.专注于从图片提取所有内容,必须根据图中箭头的导向顺序提取内容,必须保证信息的所属关系顺序,按照markdown的布局展示,否则你将受到惩罚。
# 2.严禁修改提取到的任何内容,不要丢失信息,否则你将受到惩罚。提取到的文本内容都是可以在图中找到的。
# 3.解析的内容必须是结构化的,对于表格内容需要保证markdown的。
# 4.如果是正文文字部分,例如段落序号(一、二....)、答案、例题、例、例N这样的内容应该当做一个段落使用标题符号#来提取,内容部分放在以#为标题的下方,一定不要放到同一行上,保持markdown的标准,如果没有明显的段落标识,则将正文部分组合起来当成一个段落使用##来展示。
# 5. 对于转写段落公式和行内公式的时候:保持原有的字母大小写。不进行任何字符替换或修改。精确复制图像中的每一个字母和符号。
# 6.提取的所有内容都放到三个引号中,并前边放置r,避免字符转义例如:r"""'text'"""。
# ####再次强调,不要解释和输出无关的文字,不要修改图中提取到的任何内容、符号、正负号等等,否则你将受到惩罚。
# """
# # """
# #内容过滤v2
# Context_Filter_Prompts = """
# SYSTEM:
# 你将获得一份文本内容。你的任务是专注于提取文本中的公式以及公式的名称和定义内容,过滤掉各种题目内容、表达式内容和变量定义等内容,*不要修改提取的任何文本内容*。如果内容中包含公式则将提取到的公式以及名称与定义,则将公式以及公式的名称和定义内容按照Markdown的格式输出,如果没有公式则输出空字符串'',不要自己编造和修改文本,否则你受到惩罚。
# 公式是用符号和变量表示的数学关系或规律,是解决特定类型问题的工具。例如,面积公式、体积公式等。仅仅是中文描述信息一定没有公式。你专注于分辨公式与表达式的区别,因为你能很简单的过滤掉变量。表达式:x+y。变量:x,y,z=5等。公式是经过严格证明的数学命题,是在特定公理体系下通过逻辑推理得出的结论。例如,勾股定理就是一个经典的数学定理。
# ###必须满足:提取公式的同时需要提取前三行的上下文信息,如果公式是包含在$符号之间需要提取前三行描述公式的文本内容,包含在$$之间的需要提取临近三行的文本内容。表格中的公式需要提取表格内容和结构内容。一定不要包含在```markdown ```中。
# USER:
# ""<input>""
# Goal:
# 1.你的任务是专注于提取文本中的公式以及公式的名称和定义内容,过滤掉各种题目内容、表达式内容和变量定义等内容,有结果值的都是表达式需要过滤,*不要修改提取的任何文本内容*。如果内容中包含公式则将提取到的公式以及名称与定义,则将公式以及公式的名称和定义内容按照Markdown的格式输出,如果没有公式则输出空字符串'',不要自己编造和修改文本,否则你受到惩罚。
# 2.*不要修改提取的任何内容*,否则你将受到惩罚!过滤后的文本内容需要和输入的数据式一致性的,可以来复查的。一定不要包含在```markdown ```中。
# 3.如果输入没有任何公式,那么它一定是没有公式,输出空字符串即可,严禁按照输入信息补写、续写、改写等,否则你将受到惩罚。
# 4.将提取到的公式、公式对应的名称以及定义等文本信息按照markdown的形式输出,仅将抽取的结果输出。
# # 5.必须过滤掉类似于这样的运算式和式子:r"\overrightarrow{BA}=\overrightarrow{OC}=a-b"、r"A = A_1A_2 \cup A_1 \overline{A_2} \cup \overline{A_1}A_2"、r"A_1 \\overline{A_2} = A_1 \\overline{A_2}"、'$$\\cdots x-2 \\leq 0, x+3 \\geq 0, x-5 < 0.$$'、a-2=0、b+1/2=0、r"\therefore x = \pm \sqrt{1800}。"、r'\Delta v_{\text{飞}} = 300 - 0 = 300(km/s) = 83(m/s)'、 r"$|x-2|+\sqrt{(x+3)^2}+\sqrt{x^2-10x+25}$"、r"\overrightarrow{CA} = \overrightarrow{OA} - \overrightarrow{OC}=__"、r"$h = 5t^2$",x = 30 \sqrt{2}、'\\overrightarrow{AC}=\\overrightarrow{AB}+\\overrightarrow{AD}=a+b', '\\overrightarrow{DB}=\\overrightarrow{AB}-\\overrightarrow{AD}=a-b'。
# # 5.2 具有推导过程的都需要过滤,必须过滤掉的运算式案例:1^2 + 2^2 = \frac{2 \times 3 \times 5}{6}、y=\frac{1}{2}、r"\overrightarrow{CA} = \overrightarrow{OA} - \overrightarrow{OC}"、 r"\cos (A - B) \cos (B - C) \cos (C - A) = 1"、r"\\overrightarrow{AC}=\\overrightarrow{AB}+\\overrightarrow{AD}=a+b"、't+\\frac{1}{2}\\lambda=\\frac{3}{2}\\lambda+1=0'、'\\frac{2}{30} + \\frac{8}{30} + \\frac{8}{30} + \\frac{18}{30} = \\frac{3}{5}'、
# """
# #公式审校----添加公式名称的纠错能力!
# Context_Checker_Prompts = """
# SYSTEM:
# 你将获得一份文本内容,包含了公式和文本内容,文本内容可能包含公式的名称以及定义。你的任务有三个:(1)找到出错的所有定理公式和公式的名称与描述(2)出错的原因(3)根据公式的名称、上下文内容和定义来将找到的出错公式纠正。如果存在错误,则最后将所有(出错的定理公式和公式名称)、出错原因和对应的(纠正后的定理公式和公式名称) 按照{"error_formula":[],"error_reason":[],"corrected_formula":[]}的格式输出。如果没有错误则返回空的{"error_formula":[],"error_reason":[],"corrected_formula":[]},没有错误不要自己创造错误,也不要随机捏造错误原因,否则将受到乘法。
# 根据公式的名称以及上下文来发现错误并纠正,公式与常见的用法和定义匹配,即可认为公式是正确的。没有上下文的情况下,你只需要输出置信的结果,缺少信息,仅关注公式本身的错误,例如:U1=U2=U3等是正确的,不要考虑使用条件的问题,是正确的,没有缺少等号这样的错误,可能是没有识别到。如果包含文本,那么考虑文本与公式的对应关系。
# USER:
# ""<input>""
# Goal:
# 1. 你的任务有三个:(1)找到出错的所有定理公式和公式的名称与描述(2)出错的原因(3)根据公式的名称、上下文内容和定义来将找到的出错公式纠正。最后将所有(出错的定理公式和公式名称)、出错原因和对应的(纠正后的定理公式和公式名称) 按照{"error_formula":[],"error_reason":[],"corrected_formula":[]}的格式输出。如果没有错误则返回空的{"error_formula":[],"error_reason":[],"corrected_formula":[]},不要自己创造错误,否则将受到乘法。
# 2. 定理公式的科目是:数学和物理,物理中有一些公式常见的用法:电阻使用大写R、在电荷的定义表示中使用大写Q、自感电动势中使用大写的I、磁感应强度使用B、W电工和Q焦耳定律式是需要时间变量t、变量名使用错误、极限符号不全(lim)、单位缺少平方、多或少变量、2次方和3次方混用、大小写错误、加号和减号混用、在存在上下文前提下,使用的符号不是上下文中定义的符号、周期与频率符号的正确使用、多余的变量的添加、大小号、小于号、等于号混用。
# 2.2 数学中有一些常见问题:公式的标准形式写错、乘法与加法符号混淆、2次方符号错写成3次方、公式与描述不符,例如\sqrt{}是一次根号、根据描述,公式的大于号与小于号写反、正负号写错、缺少平方符号、公式变量缺少范围、变量的范围错误。公式范围在小学、初中以及高中课本中能找到明确定义的内容,一步步思考一个公式的合理性,合理则认为正确。如果没有公式对应的文本信息,只关注公式本身的错误,不要考虑公式是否适用于当前情况。如果包含文本,那么考虑公式的名称与公式的对应关系。
# 3. 再次强调,输出格式严格按照{"error_formula":[],"error_reason":[],"corrected_formula":[]},否则你将受到惩罚!内容不要包含在```markdown ```中。其中error_formula和corrected_formula的内容要标准的Latex格式,可能包含公式的名称。
# 4. 不要使用无法解析的Latex符号,例如\\rac、\\x0crac,除法使用\\frac。不要添加过多的\\。error_reason要简洁。输出的结果中error_formula和corrected_formula中的每一个Latex公式使用单斜杠,还有error_reason,字符串前都加一个r,例如:r"formula"、r"reason",避免转义造成错误。内容不要包含在```markdown ```中。
# """
# # 特殊字符处理
# pdf_character_map = {
# '\u2002': ' ',
# '\u3000': ' ',
# '\u2003': ' ',
# }
# VLM_Match_User_Prompt="""1. 根据提供的{{image_url}}加载并查看图像并理解。
# 2. 阅读图像标题{{user_text}},注意其中的作品名称、人物、实体名称、颜色、纹理、数量、情景和角色等等可能的元素。
# 3. 比较图像标题与图像理解描述,判断两者是否匹配。
# 4. 若得到的信息无法判断,则认定为当前图像和标题为“匹配”。
# 5. 如果匹配,设置{{match_result}}为"匹配";如果不匹配,设置{{match_result}}为"不匹配"。
# 6. 根据比较结果,编写一个解释作为{{reason}},说明为什么匹配或不匹配。
# 7. 输出格式应严格按照以下格式:
# 匹配结果:{{match_result}}
# 原因:{{reason}}
# """
# VLM_Match_User_Prompt="""1. 根据提供的{{image_url}}加载并查看图像并理解。
# 2. 阅读图像标题{{user_text}},注意其中的作品名称、人物、实体名称、颜色、纹理、数量、情景和角色等等可能的元素。
# 3. 比较图像标题与图像理解描述,判断两者是否匹配。
# #4. 若得到的信息无法判断或者不完全耦合的情况,存在可能包含的关系,则设置最后的结果为"匹配"。
# 5. 如果匹配,设置{{match_result}}为"匹配";如果不匹配,设置{{match_result}}为"不匹配"。
# 6. 根据比较结果,编写一个解释作为{{reason}},说明为什么匹配或不匹配。
# 7. 输出格式应严格按照以下格式:
# 匹配结果:{{match_result}}
# 原因:{{reason}}
# """
### 对齐测试 正在使用
VLM_Match_User_Prompt="""1. 根据提供的{{image_url}}加载并查看图像并理解。
2. 阅读图像标题{{user_text}},注意其中的作品名称、人物、实体名称、颜色、纹理、数量、情景和角色等等可能的元素。
3. 比较图像标题与图像理解描述,判断两者是否匹配。
#4. 若得到的信息无法判断或者不完全耦合的情况,存在可能包含的关系,则设置最后的结果为"匹配"。
5. 如果匹配,设置{{match_result}}为"匹配";如果不匹配,设置{{match_result}}为"不匹配"。
6. 根据比较结果,编写一个解释作为{{reason}},说明为什么匹配或不匹配。
7. 输出格式应严格按照以下格式:
匹配结果:{{match_result}}
原因:{{reason}}
"""
# VLM_Match_User_Prompt="""# 图像-标题匹配判断
# ## 执行步骤
# 1. **观察图像**{{image_url}}:识别人物或实体数量、实体名称和物体特征、实体位置(上面、中间、下面)、颜色纹理、文字、场景环境、动作状态。
# 2. **解析标题**{{user_text}}:提取关键描述要素(人物或实体数量、实体名称和物体特征、实体位置和方位、颜色纹理、文字、场景环境、动作状态。等)
# 3. **对比验证**:逐项检查核心要素是否一致
# - **必须匹配**:实体名称、实体数量、实体位置、颜色纹理、文字匹配等
# - **重点关注**:识别人物或实体数量、实体名称和物体特征、实体位置(上面、中间、下面)、颜色纹理、文字、场景环境、动作状态。
# - **允许差异**:艺术风格、视角角度、细微色差、表现手法
# 4. **判断标准**:
# - 匹配:标题准确反映图像核心内容
# - 不匹配:存在明显矛盾或错误描述
# - 信息不足或模糊时默认为"匹配"
# 5. 案例信息:
# case1:数量问题。图中描述只有一个、一根、一扇、一种实体。 文本中描述的是三个、三根、三扇或者三种实体。 故而,不匹配。
# case2:实体问题。图中描述的实体是A(风筝),文本描述中找不到A的描述,反而在描述实体B(大树)。故而,不匹配。
# case3:实体和人物混淆。图中描述的一幅作品、一个实体装饰、一群人。文中则在描述创作的人、场景、情景或者世间。故而,不匹配。
# case4:实体没有包含关系。图中描述了一个场景,例如:嬉戏打闹 ,一个红灯笼,一颗大树。文中没有找到和图中描述的实体。故而,不匹配。
# case5:实体属性不符。图中有一群男人。文中描述了女人。描述的实体属性错误,不包含。故而,不匹配。
# case6:文字不符。图中描述的文字包括:天下一统、第K次会议等。文中描述的内容和图中描述不符,文中是天下大乱、第N次会议。故而,不匹配。
# case7:颜色不符。图中描述的实体颜色和文中不符。
# 6. 如果匹配,设置{{match_result}}为"匹配";如果不匹配,设置{{match_result}}为"不匹配"。
# 7. 原因尽量精简,可以有每个问题简短的分析过程!
# 严格遵循输出格式,不要添加#等符号:
# 匹配结果:{match_result}
# 原因:{reason}
# [简要说明匹配/不匹配的具体原因,重点指出关键差异]
# 分析过程:"""
# VLM_Match_Context_User_Prompt="""1. 根据提供的{{image_url}}加载并查看图像并理解。
# 2. 阅读包含图像描述的正文{{user_text}}以及图像标题{{caption}},注意正文和图像理解中的作品名称、人物、实体名称、颜色、纹理、人物/实体数量、情景和角色等等可能的元素,都需要判断是否匹配。
# 不关注大范围的情节温和,在当前的文本片段和图像是否匹配,特别是图中的人物数量等。
# 判断正文的描述是否在图中相匹配。
# 3. 比较图像理解内容与用户正文匹配的片段描述,判断两者是否匹配。
# 4. 如果匹配,设置{{match_result}}为"匹配";如果不匹配,设置{{match_result}}为"不匹配"。
# 5. 根据比较结果,编写一个解释作为{{reason}},说明为什么匹配或不匹配。
# 6.若得到的信息无法判断,则认定为当前图像和标题为“匹配”。
# 7. 输出格式应严格按照以下格式:
# 匹配结果:{{match_result}}
# 原因:{{reason}}
# """
# 正在使用
VLM_Match_Context_User_Prompt="""# 图像-文本匹配判断任务
## 操作步骤
1. **观察图像**{{image_url}}:识别人物或实体数量、实体名称和物体特征、实体位置(上面、中间、下面)、颜色纹理、文字、场景环境、动作状态。
2. **解析文本**{{user_text}}和标题{{caption}}:提取关键描述要素(人物或实体数量、实体名称和物体特征、实体位置和方位、颜色纹理、文字、场景环境、动作状态。等)
3. **对比匹配**:识别人物或实体数量、实体名称和物体特征、实体位置(上面、中间、下面)、颜色、文字。
4. **判断标准**:
- **匹配**:关键要素一致,允许细节差异
- **不匹配**:实体数量不匹配、位置不匹配、实体名称错误、颜色不同、文字不符、场景根本不同,允许其他细节差异
- **信息不足时默认判定为匹配**
5. 案例信息:
case1:数量问题。图中描述只有一个、一根、一扇、一种实体。 文本中描述的是三个、三根、三扇或者三种实体。 故而,不匹配。
case2:实体名称。图中描述的实体是A(风筝),文本描述中找不到A的描述,反而在描述实体B(大树)。故而,不匹配。
case3:实体和人物混淆。图中描述的一幅作品、一个实体装饰、一群人。文中则在描述创作的人、场景、情景或者世间。故而,不匹配。
case4:实体没有包含关系。图中描述了一个场景,例如:嬉戏打闹 ,一个红灯笼,一颗大树。文中没有找到和图中描述的实体。故而,不匹配。
case5:实体属性不符。图中有一群男人。文中描述了女人。描述的实体属性错误,不包含。故而,不匹配。
case6:文字不符。图中描述的文字包括:天下一统、第K次会议等。文中描述的内容和图中描述不符,文中是天下大乱、第N次会议。故而,不匹配。
case7:颜色不符。图中描述的实体颜色和文中不符。
case8:实体位置和方位错误。图中描述的实体在某个位置上面,文中描述的是在中间放置,位置描述错误。故而,不匹配。
6. 如果匹配,设置{{match_result}}为"匹配";如果不匹配,设置{{match_result}}为"不匹配"。
7. 原因需要精简,可以有每个问题简短的分析过程!
严格遵循输出格式,不要添加#等符号:
匹配结果:{match_result}
原因:{reason}
[简要说明匹配/不匹配的具体原因,重点指出关键差异]
分析过程:"""
import requests
def test_service():
#url_path='https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241022/emyrja/dog_and_girl.jpeg'
url_path='https://oss.5rs.me/oss/upload/image/png/d038e42a43154bb499810096446888f8.png'
#url_path='/home/wangtengbo/A800-13-nfs/Image_Text_Matching_Server_Develop/logs/2025-06-23/images/dog_and_girl_20250623024916296.jpeg'
data_info = {
"illustration_url": url_path,
'caption_text':'庆祝澳门回归',
"context_info":''
}
try:
url = 'http://localhost:29505/v1/image_text_matching' # 王腾博部署
url = 'http://localhost:29500/v1/image_text_matching' # 测试和生产部署的api接口
response = requests.post(url, json=data_info)
print('Comment response status code:', response.status_code)
if response.status_code == 200:
print('Response content:', response.json())
else:
print('Error response content:', response.text)
except requests.exceptions.RequestException as e:
print(f"Request failed: {e}")
if __name__ == "__main__":
test_service()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
图文匹配检测服务
完整的Web服务器实现
"""
import http.server
import socketserver
import os
import json
import urllib.parse
from pathlib import Path
import time
import requests
import logging
# from obs_upload import OBSUploader
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
import requests
import os
import mimetypes
from typing import Dict, Optional, Union, Tuple
from urllib.parse import quote
class OBSUploader:
def __init__(self, base_url: str = "https://open.raysgo.com", auth_token: Optional[str] = None):
"""
Initialize the OBS uploader.
Args:
base_url: The base URL for the API
auth_token: The authorization token for API access
"""
self.base_url = base_url.rstrip('/')
self.auth_token = auth_token
self.headers = {
'Authorization': f'Bearer {auth_token}' if auth_token else None
}
# Initialize mimetypes
mimetypes.init()
def _get_content_type(self, file_path: Union[str, bytes]) -> Tuple[str, bytes]:
"""
Get content type and file content from file path or bytes.
Args:
file_path: Path to the file or file content as bytes
Returns:
Tuple of (content_type, file_content)
"""
if isinstance(file_path, str):
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
content_type, _ = mimetypes.guess_type(file_path)
with open(file_path, 'rb') as f:
file_content = f.read()
else:
file_content = file_path
# For bytes input, try to detect type from first few bytes
content_type = 'application/octet-stream' # Default content type
return content_type or 'application/octet-stream', file_content
def get_upload_url(self, biz_code: str, object_name: str, content_type: str) -> Dict:
"""
Get a temporary upload URL for the specified object.
Args:
biz_code: Business code for the upload
object_name: Name/path of the object to upload
content_type: MIME type of the file
Returns:
Dict containing the upload URL and related information
"""
endpoint = f"{self.base_url}/aimodel/v1.0/obs/getCreatePostSignature"
params = {
'bizCode': biz_code,
'objectName': object_name,
'mimeType': content_type
}
response = requests.get(endpoint, params=params, headers=self.headers)
response.raise_for_status()
return response.json()
def upload_file(self, file_path: Union[str, bytes], biz_code: str, object_name: str) -> Dict:
"""
Upload a file using temporary credentials.
Args:
file_path: Path to the file to upload or file content as bytes
biz_code: Business code for the upload
object_name: Name/path of the object to upload
Returns:
Dict containing the upload result and file URL
"""
# Get content type and file content
content_type, file_content = self._get_content_type(file_path)
# Get temporary upload URL with content type
upload_info = self.get_upload_url(biz_code, object_name, content_type)
if upload_info['errCode'] != 0:
raise Exception(f"Failed to get upload URL: {upload_info['message']}")
upload_url = upload_info['data']['temporarySignatureUrl']
# Upload the file with the correct content type
headers = {
'Content-Type': content_type,
'Content-Length': str(len(file_content))
}
response = requests.put(upload_url, data=file_content, headers=headers)
response.raise_for_status()
return {
'success': True,
'file_url': upload_info['data']['domain'] + '/' + object_name,
'object_url_map': upload_info['data']['objectUrlMap']
}
class ImageTextMatchingHandler(http.server.SimpleHTTPRequestHandler):
def __init__(self, *args, **kwargs):
# 设置静态文件目录
super().__init__(*args, directory="web", **kwargs)
self.uploader = OBSUploader(auth_token="dcg-4c1e3a7f4fcd415e8c93151ff539d20a")
if self.uploader:
print('uploader_sucess!')
else:
print('uploader_failed!')
def setup(self):
super().setup()
if not hasattr(self, 'uploader'):
self.uploader = OBSUploader(auth_token="dcg-4c1e3a7f4fcd415e8c93151ff539d20a")
print('uploader_initialized in setup')
def end_headers(self):
# 添加CORS头
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
self.send_header('Access-Control-Allow-Headers', 'Content-Type')
super().end_headers()
def do_OPTIONS(self):
# 处理预检请求
self.send_response(200)
self.end_headers()
def do_POST(self):
"""处理POST请求"""
if self.path == '/upload':
self.handle_upload()
elif self.path == '/api/image_text_matching':
self.handle_matching_api()
else:
self.send_error(404)
def do_GET(self):
"""处理GET请求"""
if self.path == '/api/health':
# 健康检查端点
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
response = {"status": "ok", "service": "image_text_matching"}
self.wfile.write(json.dumps(response).encode())
else:
# 静态文件处理
super().do_GET()
def handle_upload(self):
"""处理文件上传"""
try:
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
# 解析multipart/form-data
content_type = self.headers['Content-Type']
if 'multipart/form-data' in content_type:
boundary = content_type.split('boundary=')[1].encode()
parts = post_data.split(b'--' + boundary)
for part in parts:
if b'Content-Disposition: form-data' in part and b'filename=' in part:
# 提取文件名
lines = part.split(b'\r\n')
filename = "unknown.jpg"
for line in lines:
if b'Content-Disposition' in line:
try:
filename_start = line.find(b'filename="') + 10
filename_end = line.find(b'"', filename_start)
filename = line[filename_start:filename_end].decode()
except:
pass
break
# 提取文件内容
content_start = part.find(b'\r\n\r\n') + 4
content_end = part.rfind(b'\r\n')
if content_end == -1:
content_end = len(part)
file_content = part[content_start:content_end]
if len(file_content) > 0:
# 保存文件
upload_dir = Path("uploads")
upload_dir.mkdir(exist_ok=True)
timestamp = str(int(time.time() * 1000))
file_extension = filename.split('.')[-1] if '.' in filename else 'jpg'
save_filename = f"upload_{timestamp}.{file_extension}"
file_path = upload_dir / save_filename
with open(file_path, 'wb') as f:
f.write(file_content)
# Upload a file
try:
logger.info(f"file_path=/home/wangtengbo/uploads/{save_filename}")
result = self.uploader.upload_file(
file_path=f"/home/wangtengbo/uploads/{save_filename}",
biz_code="image_text_matching",
object_name=f"image/{save_filename}"
)
file_url=result['file_url']
logger.info('用户上传的图像转换为file_url={file_url}')
except Exception as e:
print(f"Upload failed: {str(e)}")
# 返回文件URL #这里需要封装一个oss
#file_url = f"http://localhost:8080/uploads/{save_filename}"
response = {"success": True, "url": file_url}
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response).encode())
return
self.send_error(400, "Bad Request - No valid file found")
except Exception as e:
logger.error(f"Upload error: {e}")
self.send_error(500, f"Internal Server Error: {str(e)}")
def handle_matching_api(self):
"""处理图文匹配API请求"""
try:
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
# 解析JSON数据
data = json.loads(post_data.decode('utf-8'))
# 调用图文匹配检测服务
result = self.call_matching_service(data)
# 返回结果
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(result, ensure_ascii=False).encode('utf-8'))
except json.JSONDecodeError as e:
logger.error(f"JSON decode error: {e}")
self.send_error(400, "Invalid JSON data")
except Exception as e:
logger.error(f"Matching API error: {e}")
error_response = {
"code": 500,
"message": f"服务器内部错误: {str(e)}",
"result": []
}
self.send_response(500)
self.send_header('Content-Type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(error_response, ensure_ascii=False).encode('utf-8'))
def call_matching_service(self, data):
"""调用外部图文匹配检测服务"""
api_url = "http://localhost:29500/v1/image_text_matching"
try:
# 准备请求数据
request_data = {
"illustration_url": data.get("illustration_url", ""),
"caption_text": data.get("caption_text", ""),
"context_info": data.get("context_info", "")
}
logger.info(f"Calling matching service with data: {request_data}")
# 调用外部服务
response = requests.post(api_url, json=request_data, timeout=30)
if response.status_code == 200:
result = response.json()
logger.info(f"Matching service response: {result}")
return result
else:
logger.error(f"Matching service error: {response.status_code} - {response.text}")
return {
"code": response.status_code,
"message": f"外部服务错误: {response.text}",
"result": []
}
except requests.exceptions.RequestException as e:
logger.error(f"Request to matching service failed: {e}")
return {
"code": 503,
"message": f"无法连接到图文匹配检测服务 (localhost:29500): {str(e)}",
"result": []
}
except Exception as e:
logger.error(f"Unexpected error in matching service call: {e}")
return {
"code": 500,
"message": f"服务调用失败: {str(e)}",
"result": []
}
def setup_directories():
"""创建必要的目录结构"""
directories = ["web", "uploads"]
for directory in directories:
os.makedirs(directory, exist_ok=True)
print(f"✓ 创建目录: {directory}")
def create_index_html():
"""创建前端HTML文件"""
html_content = '''<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>图文匹配检测服务</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 800px;
margin: 0 auto;
background: rgba(255, 255, 255, 0.95);
border-radius: 20px;
padding: 40px;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
backdrop-filter: blur(10px);
}
.header {
text-align: center;
margin-bottom: 40px;
}
.header h1 {
color: #333;
font-size: 2.5rem;
margin-bottom: 10px;
background: linear-gradient(45deg, #667eea, #764ba2);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.header p {
color: #666;
font-size: 1.1rem;
}
.config-section {
background: #f8f9fa;
padding: 20px;
border-radius: 10px;
margin-bottom: 30px;
border-left: 4px solid #667eea;
}
.config-section h4 {
color: #333;
margin-bottom: 15px;
}
.config-input {
width: 100%;
padding: 10px;
border: 1px solid #ddd;
border-radius: 5px;
margin-bottom: 10px;
}
.form-section {
margin-bottom: 30px;
}
.form-section h3 {
color: #333;
margin-bottom: 15px;
font-size: 1.3rem;
display: flex;
align-items: center;
}
.form-section h3::before {
content: '';
width: 4px;
height: 20px;
background: linear-gradient(45deg, #667eea, #764ba2);
margin-right: 10px;
border-radius: 2px;
}
.upload-area {
border: 2px dashed #ccc;
border-radius: 15px;
padding: 40px;
text-align: center;
background: #fafafa;
transition: all 0.3s ease;
cursor: pointer;
position: relative;
overflow: hidden;
}
.upload-area:hover {
border-color: #667eea;
background: #f0f4ff;
transform: translateY(-2px);
box-shadow: 0 10px 20px rgba(102, 126, 234, 0.1);
}
.upload-area.dragover {
border-color: #667eea;
background: #e8f0ff;
}
.upload-icon {
font-size: 3rem;
color: #ccc;
margin-bottom: 20px;
}
.upload-text {
color: #666;
font-size: 1.1rem;
margin-bottom: 10px;
}
.upload-hint {
color: #999;
font-size: 0.9rem;
}
.file-input {
display: none;
}
.preview-container {
margin-top: 20px;
text-align: center;
}
.preview-image {
max-width: 100%;
max-height: 300px;
border-radius: 10px;
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
margin-bottom: 15px;
}
.file-info {
background: #f8f9fa;
padding: 15px;
border-radius: 10px;
margin-top: 15px;
}
.input-group {
margin-bottom: 25px;
}
.input-group label {
display: block;
margin-bottom: 8px;
color: #333;
font-weight: 600;
}
.input-group input,
.input-group textarea {
width: 100%;
padding: 15px;
border: 2px solid #e0e0e0;
border-radius: 10px;
font-size: 1rem;
transition: all 0.3s ease;
background: #fff;
}
.input-group input:focus,
.input-group textarea:focus {
outline: none;
border-color: #667eea;
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1);
}
.input-group textarea {
min-height: 120px;
resize: vertical;
}
.submit-btn {
width: 100%;
padding: 18px;
background: linear-gradient(45deg, #667eea, #764ba2);
color: white;
border: none;
border-radius: 12px;
font-size: 1.2rem;
font-weight: 600;
cursor: pointer;
transition: all 0.3s ease;
position: relative;
overflow: hidden;
}
.submit-btn:hover {
transform: translateY(-2px);
box-shadow: 0 15px 30px rgba(102, 126, 234, 0.3);
}
.submit-btn:disabled {
opacity: 0.6;
cursor: not-allowed;
transform: none;
}
.loading {
display: none;
margin-left: 10px;
}
.spinner {
width: 20px;
height: 20px;
border: 2px solid #ffffff30;
border-top: 2px solid #ffffff;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.result-section {
margin-top: 40px;
display: none;
}
.result-content {
background: #f8f9fa;
padding: 25px;
border-radius: 15px;
border-left: 5px solid #667eea;
}
.result-success {
border-left-color: #28a745;
background: #d4edda;
}
.result-error {
border-left-color: #dc3545;
background: #f8d7da;
}
.result-item {
background: white;
padding: 20px;
border-radius: 10px;
margin-bottom: 15px;
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.05);
}
.error-type {
color: #dc3545;
font-weight: 600;
margin-bottom: 10px;
}
.error-text {
color: #333;
margin-bottom: 10px;
padding: 10px;
background: #f8f9fa;
border-radius: 5px;
}
.error-reason {
color: #666;
font-style: italic;
}
.clear-btn {
margin-top: 20px;
padding: 10px 20px;
background: #6c757d;
color: white;
border: none;
border-radius: 8px;
cursor: pointer;
transition: all 0.3s ease;
}
.clear-btn:hover {
background: #5a6268;
}
.status-indicator {
display: inline-block;
width: 10px;
height: 10px;
border-radius: 50%;
margin-right: 8px;
}
.status-online {
background: #28a745;
}
.status-offline {
background: #dc3545;
}
@media (max-width: 768px) {
.container {
padding: 20px;
margin: 10px;
}
.header h1 {
font-size: 2rem;
}
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🔍 图文匹配检测服务</h1>
<p>上传图片并输入相关文本,检测图片与文本的匹配度</p>
</div>
<div class="config-section">
<h4>⚙️ 服务配置</h4>
<label>API服务地址:</label>
<input type="text" id="apiUrl" class="config-input" value="http://localhost:8080" placeholder="http://localhost:8080">
<div style="margin-top: 10px;">
<span>服务状态: </span>
<span class="status-indicator status-offline" id="statusDot"></span>
<span id="statusText">检测中...</span>
<button onclick="checkApiStatus()" style="margin-left: 10px; padding: 5px 10px; border: none; background: #667eea; color: white; border-radius: 5px; cursor: pointer;">检测连接</button>
</div>
</div>
<form id="matchingForm">
<div class="form-section">
<h3>📸 上传图片</h3>
<div class="upload-area" id="uploadArea">
<div class="upload-icon">📁</div>
<div class="upload-text">点击或拖拽上传图片</div>
<div class="upload-hint">支持 JPG、PNG、JPEG 格式</div>
<input type="file" id="imageFile" class="file-input" accept="image/*" required>
</div>
<div class="preview-container" id="previewContainer" style="display: none;">
<img id="previewImage" class="preview-image" alt="预览图片">
<div class="file-info" id="fileInfo"></div>
</div>
</div>
<div class="form-section">
<h3>✏️ 文本信息</h3>
<div class="input-group">
<label for="captionText">图片标题/描述文本:</label>
<input type="text" id="captionText" placeholder="请输入与图片相关的标题或描述文本" required>
</div>
<div class="input-group">
<label for="contextInfo">上下文信息(可选):</label>
<textarea id="contextInfo" placeholder="请输入相关的上下文信息,可以为空"></textarea>
</div>
</div>
<button type="submit" class="submit-btn" id="submitBtn">
🚀 开始检测
<div class="loading" id="loading">
<div class="spinner"></div>
</div>
</button>
</form>
<div class="result-section" id="resultSection">
<div class="form-section">
<h3>📊 检测结果</h3>
<div class="result-content" id="resultContent"></div>
<button type="button" class="clear-btn" onclick="clearResults()">清除结果</button>
</div>
</div>
</div>
<script>
const uploadArea = document.getElementById('uploadArea');
const fileInput = document.getElementById('imageFile');
const previewContainer = document.getElementById('previewContainer');
const previewImage = document.getElementById('previewImage');
const fileInfo = document.getElementById('fileInfo');
const form = document.getElementById('matchingForm');
const submitBtn = document.getElementById('submitBtn');
const loading = document.getElementById('loading');
const resultSection = document.getElementById('resultSection');
const resultContent = document.getElementById('resultContent');
const apiUrlInput = document.getElementById('apiUrl');
const statusDot = document.getElementById('statusDot');
const statusText = document.getElementById('statusText');
// 页面加载时检测API状态
window.onload = function() {
checkApiStatus();
};
// 检测API状态
async function checkApiStatus() {
const apiUrl = apiUrlInput.value.trim();
statusText.textContent = '检测中...';
statusDot.className = 'status-indicator status-offline';
try {
const response = await fetch(`${apiUrl}/api/health`, {
method: 'GET',
timeout: 5000
});
if (response.ok) {
statusText.textContent = '服务在线';
statusDot.className = 'status-indicator status-online';
} else {
statusText.textContent = '服务离线';
statusDot.className = 'status-indicator status-offline';
}
} catch (error) {
statusText.textContent = '连接失败';
statusDot.className = 'status-indicator status-offline';
}
}
// 上传区域点击事件
uploadArea.addEventListener('click', () => {
fileInput.click();
});
// 拖拽上传功能
uploadArea.addEventListener('dragover', (e) => {
e.preventDefault();
uploadArea.classList.add('dragover');
});
uploadArea.addEventListener('dragleave', () => {
uploadArea.classList.remove('dragover');
});
uploadArea.addEventListener('drop', (e) => {
e.preventDefault();
uploadArea.classList.remove('dragover');
const files = e.dataTransfer.files;
if (files.length > 0) {
handleFileSelect(files[0]);
}
});
// 文件选择事件
fileInput.addEventListener('change', (e) => {
if (e.target.files.length > 0) {
handleFileSelect(e.target.files[0]);
}
});
// 处理文件选择
function handleFileSelect(file) {
if (!file.type.startsWith('image/')) {
alert('请选择图片文件!');
return;
}
// 显示预览
const reader = new FileReader();
reader.onload = (e) => {
previewImage.src = e.target.result;
previewContainer.style.display = 'block';
// 显示文件信息
fileInfo.innerHTML = `
<strong>文件名:</strong>${file.name}<br>
<strong>文件大小:</strong>${(file.size / 1024 / 1024).toFixed(2)} MB<br>
<strong>文件类型:</strong>${file.type}
`;
};
reader.readAsDataURL(file);
// 将文件对象保存到input中
const dt = new DataTransfer();
dt.items.add(file);
fileInput.files = dt.files;
}
// 上传图片
async function uploadImage(file) {
const formData = new FormData();
formData.append('image', file);
try {
const response = await fetch('/upload', {
method: 'POST',
body: formData
});
if (!response.ok) {
throw new Error(`上传失败: ${response.status}`);
}
const result = await response.json();
return result.url;
} catch (error) {
console.error('上传错误:', error);
throw error;
}
}
// 表单提交事件
form.addEventListener('submit', async (e) => {
e.preventDefault();
const file = fileInput.files[0];
const captionText = document.getElementById('captionText').value.trim();
const contextInfo = document.getElementById('contextInfo').value.trim();
const apiUrl = apiUrlInput.value.trim();
if (!file) {
alert('请选择图片文件!');
return;
}
if (!captionText) {
alert('请输入图片标题或描述文本!');
return;
}
if (!apiUrl) {
alert('请输入API服务地址!');
return;
}
// 显示加载状态
submitBtn.disabled = true;
loading.style.display = 'inline-block';
resultSection.style.display = 'none';
try {
// 上传图片获取URL
console.log('开始上传图片...');
const imageUrl = await uploadImage(file);
console.log('图片上传成功:', imageUrl);
// 调用图文匹配API
console.log('开始调用图文匹配API...');
const response = await fetch(`${apiUrl}/api/image_text_matching`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
illustration_url: imageUrl,
caption_text: captionText,
context_info: contextInfo
})
});
const result = await response.json();
console.log('API响应结果:', result);
displayResult(result);
} catch (error) {
console.error('处理失败:', error);
displayError(`处理失败: ${error.message}`);
} finally {
// 隐藏加载状态
submitBtn.disabled = false;
loading.style.display = 'none';
}
});
// 显示结果
function displayResult(result) {
resultSection.style.display = 'block';
if (result.code === 200) {
if (result.result && result.result.length > 0) {
// 有匹配问题
resultContent.className = 'result-content result-error';
let html = '<h4>❌ 检测到图文不匹配问题:</h4>';
result.result.forEach((item, index) => {
html += `
<div class="result-item">
<div class="error-type">错误类型: ${item.error_type}</div>
<div class="error-text"><strong>文本内容:</strong> "${item.context_text}"</div>
<div class="error-reason"><strong>不匹配原因:</strong> ${item.error_reson}</div>
</div>
`;
});
resultContent.innerHTML = html;
} else {
// 匹配成功
resultContent.className = 'result-content result-success';
resultContent.innerHTML = `
<h4>✅ 图文匹配检测通过</h4>
<p>图片与文本内容匹配度良好,未发现明显的不匹配问题。</p>
<div class="result-item">
<strong>检测模型:</strong> ${result.model || 'qwen'}<br>
<strong>检测时间:</strong> ${new Date().toLocaleString()}
</div>
`;
}
} else {
// API错误
resultContent.className = 'result-content result-error';
resultContent.innerHTML = `
<h4>❌ 检测失败</h4>
<p><strong>错误信息:</strong> ${result.message}</p>
<p><strong>错误代码:</strong> ${result.code}</p>
`;
}
}
// 显示错误信息
function displayError(message) {
resultSection.style.display = 'block';
resultContent.className = 'result-content result-error';
resultContent.innerHTML = `
<h4>❌ 处理失败</h4>
<p>${message}</p>
`;
}
// 清除结果
function clearResults() {
resultSection.style.display = 'none';
resultContent.innerHTML = '';
}
</script>
</body>
</html>'''
with open("web/index.html", "w", encoding="utf-8") as f:
f.write(html_content)
print("✓ 创建前端文件: web/index.html")
def start_server(port=8080):
"""启动Web服务器"""
print("正在初始化图文匹配检测服务...")
# 设置目录结构
setup_directories()
# 创建前端文件
create_index_html()
print("\n" + "="*50)
print("🚀 图文匹配检测服务启动中...")
print("="*50)
# 启动服务器
with socketserver.TCPServer(("", port), ImageTextMatchingHandler) as httpd:
server_address = ("", port)
print(f"✓ 服务器绑定地址: http://localhost:{port}")
print(f"✓ 静态文件目录: web/")
print(f"✓ 上传文件目录: uploads/")
print(f"✓ 外部API地址: http://localhost:29500")
try:
print(f"\n🌐 服务器启动成功!")
print(f"📱 请在浏览器中访问: http://localhost:{port}")
print(f"🔧 管理界面: http://localhost:{port}/api/health")
print("\n" + "="*50)
print("按 Ctrl+C 停止服务器")
print("="*50)
httpd.serve_forever()
except KeyboardInterrupt:
print("\n\n⏹️ 正在停止服务器...")
httpd.shutdown()
print("✓ 服务器已停止")
if __name__ == "__main__":
import sys
# 解析命令行参数
port = 8080
if len(sys.argv) > 1:
try:
port = int(sys.argv[1])
except ValueError:
print("❌ 端口号必须是数字")
sys.exit(1)
# 启动服务器
start_server(port)
\ No newline at end of file
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
图文匹配检测服务
完整的Web服务器实现
"""
import http.server
import socketserver
import os
import json
import urllib.parse
from pathlib import Path
import time
import requests
import logging
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ImageTextMatchingHandler(http.server.SimpleHTTPRequestHandler):
def __init__(self, *args, **kwargs):
# 设置静态文件目录
super().__init__(*args, directory="web", **kwargs)
def end_headers(self):
# 添加CORS头
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
self.send_header('Access-Control-Allow-Headers', 'Content-Type')
super().end_headers()
def do_OPTIONS(self):
# 处理预检请求
self.send_response(200)
self.end_headers()
def do_POST(self):
"""处理POST请求"""
if self.path == '/upload':
self.handle_upload()
elif self.path == '/api/image_text_matching':
self.handle_matching_api()
else:
self.send_error(404)
def do_GET(self):
"""处理GET请求"""
if self.path == '/api/health':
# 健康检查端点
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
response = {"status": "ok", "service": "image_text_matching"}
self.wfile.write(json.dumps(response).encode())
else:
# 静态文件处理
super().do_GET()
def handle_upload(self):
"""处理文件上传"""
try:
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
# 解析multipart/form-data
content_type = self.headers['Content-Type']
if 'multipart/form-data' in content_type:
boundary = content_type.split('boundary=')[1].encode()
parts = post_data.split(b'--' + boundary)
for part in parts:
if b'Content-Disposition: form-data' in part and b'filename=' in part:
# 提取文件名
lines = part.split(b'\r\n')
filename = "unknown.jpg"
for line in lines:
if b'Content-Disposition' in line:
try:
filename_start = line.find(b'filename="') + 10
filename_end = line.find(b'"', filename_start)
filename = line[filename_start:filename_end].decode()
except:
pass
break
# 提取文件内容
content_start = part.find(b'\r\n\r\n') + 4
content_end = part.rfind(b'\r\n')
if content_end == -1:
content_end = len(part)
file_content = part[content_start:content_end]
if len(file_content) > 0:
# 保存文件
upload_dir = Path("uploads")
upload_dir.mkdir(exist_ok=True)
timestamp = str(int(time.time() * 1000))
file_extension = filename.split('.')[-1] if '.' in filename else 'jpg'
save_filename = f"upload_{timestamp}.{file_extension}"
file_path = upload_dir / save_filename
with open(file_path, 'wb') as f:
f.write(file_content)
# 返回文件URL
file_url = f"http://localhost:8080/uploads/{save_filename}"
response = {"success": True, "url": file_url}
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response).encode())
return
self.send_error(400, "Bad Request - No valid file found")
except Exception as e:
logger.error(f"Upload error: {e}")
self.send_error(500, f"Internal Server Error: {str(e)}")
def handle_matching_api(self):
"""处理图文匹配API请求"""
try:
content_length = int(self.headers['Content-Length'])
post_data = self.rfile.read(content_length)
# 解析JSON数据
data = json.loads(post_data.decode('utf-8'))
# 调用图文匹配检测服务
result = self.call_matching_service(data)
# 返回结果
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(result, ensure_ascii=False).encode('utf-8'))
except json.JSONDecodeError as e:
logger.error(f"JSON decode error: {e}")
self.send_error(400, "Invalid JSON data")
except Exception as e:
logger.error(f"Matching API error: {e}")
error_response = {
"code": 500,
"message": f"服务器内部错误: {str(e)}",
"result": []
}
self.send_response(500)
self.send_header('Content-Type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(error_response, ensure_ascii=False).encode('utf-8'))
def call_matching_service(self, data):
"""调用外部图文匹配检测服务"""
api_url = "http://localhost:29500/v1/image_text_matching"
try:
# 准备请求数据
request_data = {
"illustration_url": data.get("illustration_url", ""),
"caption_text": data.get("caption_text", ""),
"context_info": data.get("context_info", "")
}
logger.info(f"Calling matching service with data: {request_data}")
# 调用外部服务
response = requests.post(api_url, json=request_data, timeout=30)
if response.status_code == 200:
result = response.json()
logger.info(f"Matching service response: {result}")
return result
else:
logger.error(f"Matching service error: {response.status_code} - {response.text}")
return {
"code": response.status_code,
"message": f"外部服务错误: {response.text}",
"result": []
}
except requests.exceptions.RequestException as e:
logger.error(f"Request to matching service failed: {e}")
return {
"code": 503,
"message": f"无法连接到图文匹配检测服务 (localhost:29500): {str(e)}",
"result": []
}
except Exception as e:
logger.error(f"Unexpected error in matching service call: {e}")
return {
"code": 500,
"message": f"服务调用失败: {str(e)}",
"result": []
}
def setup_directories():
"""创建必要的目录结构"""
directories = ["web", "uploads"]
for directory in directories:
os.makedirs(directory, exist_ok=True)
print(f"✓ 创建目录: {directory}")
def create_index_html():
"""创建前端HTML文件"""
html_content = '''<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>图文匹配检测服务</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 800px;
margin: 0 auto;
background: rgba(255, 255, 255, 0.95);
border-radius: 20px;
padding: 40px;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
backdrop-filter: blur(10px);
}
.header {
text-align: center;
margin-bottom: 40px;
}
.header h1 {
color: #333;
font-size: 2.5rem;
margin-bottom: 10px;
background: linear-gradient(45deg, #667eea, #764ba2);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.header p {
color: #666;
font-size: 1.1rem;
}
.config-section {
background: #f8f9fa;
padding: 20px;
border-radius: 10px;
margin-bottom: 30px;
border-left: 4px solid #667eea;
}
.config-section h4 {
color: #333;
margin-bottom: 15px;
}
.config-input {
width: 100%;
padding: 10px;
border: 1px solid #ddd;
border-radius: 5px;
margin-bottom: 10px;
}
.form-section {
margin-bottom: 30px;
}
.form-section h3 {
color: #333;
margin-bottom: 15px;
font-size: 1.3rem;
display: flex;
align-items: center;
}
.form-section h3::before {
content: '';
width: 4px;
height: 20px;
background: linear-gradient(45deg, #667eea, #764ba2);
margin-right: 10px;
border-radius: 2px;
}
.upload-area {
border: 2px dashed #ccc;
border-radius: 15px;
padding: 40px;
text-align: center;
background: #fafafa;
transition: all 0.3s ease;
cursor: pointer;
position: relative;
overflow: hidden;
}
.upload-area:hover {
border-color: #667eea;
background: #f0f4ff;
transform: translateY(-2px);
box-shadow: 0 10px 20px rgba(102, 126, 234, 0.1);
}
.upload-area.dragover {
border-color: #667eea;
background: #e8f0ff;
}
.upload-icon {
font-size: 3rem;
color: #ccc;
margin-bottom: 20px;
}
.upload-text {
color: #666;
font-size: 1.1rem;
margin-bottom: 10px;
}
.upload-hint {
color: #999;
font-size: 0.9rem;
}
.file-input {
display: none;
}
.preview-container {
margin-top: 20px;
text-align: center;
}
.preview-image {
max-width: 100%;
max-height: 300px;
border-radius: 10px;
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
margin-bottom: 15px;
}
.file-info {
background: #f8f9fa;
padding: 15px;
border-radius: 10px;
margin-top: 15px;
}
.input-group {
margin-bottom: 25px;
}
.input-group label {
display: block;
margin-bottom: 8px;
color: #333;
font-weight: 600;
}
.input-group input,
.input-group textarea {
width: 100%;
padding: 15px;
border: 2px solid #e0e0e0;
border-radius: 10px;
font-size: 1rem;
transition: all 0.3s ease;
background: #fff;
}
.input-group input:focus,
.input-group textarea:focus {
outline: none;
border-color: #667eea;
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1);
}
.input-group textarea {
min-height: 120px;
resize: vertical;
}
.submit-btn {
width: 100%;
padding: 18px;
background: linear-gradient(45deg, #667eea, #764ba2);
color: white;
border: none;
border-radius: 12px;
font-size: 1.2rem;
font-weight: 600;
cursor: pointer;
transition: all 0.3s ease;
position: relative;
overflow: hidden;
}
.submit-btn:hover {
transform: translateY(-2px);
box-shadow: 0 15px 30px rgba(102, 126, 234, 0.3);
}
.submit-btn:disabled {
opacity: 0.6;
cursor: not-allowed;
transform: none;
}
.loading {
display: none;
margin-left: 10px;
}
.spinner {
width: 20px;
height: 20px;
border: 2px solid #ffffff30;
border-top: 2px solid #ffffff;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.result-section {
margin-top: 40px;
display: none;
}
.result-content {
background: #f8f9fa;
padding: 25px;
border-radius: 15px;
border-left: 5px solid #667eea;
}
.result-success {
border-left-color: #28a745;
background: #d4edda;
}
.result-error {
border-left-color: #dc3545;
background: #f8d7da;
}
.result-item {
background: white;
padding: 20px;
border-radius: 10px;
margin-bottom: 15px;
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.05);
}
.error-type {
color: #dc3545;
font-weight: 600;
margin-bottom: 10px;
}
.error-text {
color: #333;
margin-bottom: 10px;
padding: 10px;
background: #f8f9fa;
border-radius: 5px;
}
.error-reason {
color: #666;
font-style: italic;
}
.clear-btn {
margin-top: 20px;
padding: 10px 20px;
background: #6c757d;
color: white;
border: none;
border-radius: 8px;
cursor: pointer;
transition: all 0.3s ease;
}
.clear-btn:hover {
background: #5a6268;
}
.status-indicator {
display: inline-block;
width: 10px;
height: 10px;
border-radius: 50%;
margin-right: 8px;
}
.status-online {
background: #28a745;
}
.status-offline {
background: #dc3545;
}
@media (max-width: 768px) {
.container {
padding: 20px;
margin: 10px;
}
.header h1 {
font-size: 2rem;
}
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🔍 图文匹配检测服务</h1>
<p>上传图片并输入相关文本,检测图片与文本的匹配度</p>
</div>
<div class="config-section">
<h4>⚙️ 服务配置</h4>
<label>API服务地址:</label>
<input type="text" id="apiUrl" class="config-input" value="http://localhost:8080" placeholder="http://localhost:8080">
<div style="margin-top: 10px;">
<span>服务状态: </span>
<span class="status-indicator status-offline" id="statusDot"></span>
<span id="statusText">检测中...</span>
<button onclick="checkApiStatus()" style="margin-left: 10px; padding: 5px 10px; border: none; background: #667eea; color: white; border-radius: 5px; cursor: pointer;">检测连接</button>
</div>
</div>
<form id="matchingForm">
<div class="form-section">
<h3>📸 上传图片</h3>
<div class="upload-area" id="uploadArea">
<div class="upload-icon">📁</div>
<div class="upload-text">点击或拖拽上传图片</div>
<div class="upload-hint">支持 JPG、PNG、JPEG 格式</div>
<input type="file" id="imageFile" class="file-input" accept="image/*" required>
</div>
<div class="preview-container" id="previewContainer" style="display: none;">
<img id="previewImage" class="preview-image" alt="预览图片">
<div class="file-info" id="fileInfo"></div>
</div>
</div>
<div class="form-section">
<h3>✏️ 文本信息</h3>
<div class="input-group">
<label for="captionText">图片标题/描述文本:</label>
<input type="text" id="captionText" placeholder="请输入与图片相关的标题或描述文本" required>
</div>
<div class="input-group">
<label for="contextInfo">上下文信息(可选):</label>
<textarea id="contextInfo" placeholder="请输入相关的上下文信息,可以为空"></textarea>
</div>
</div>
<button type="submit" class="submit-btn" id="submitBtn">
🚀 开始检测
<div class="loading" id="loading">
<div class="spinner"></div>
</div>
</button>
</form>
<div class="result-section" id="resultSection">
<div class="form-section">
<h3>📊 检测结果</h3>
<div class="result-content" id="resultContent"></div>
<button type="button" class="clear-btn" onclick="clearResults()">清除结果</button>
</div>
</div>
</div>
<script>
const uploadArea = document.getElementById('uploadArea');
const fileInput = document.getElementById('imageFile');
const previewContainer = document.getElementById('previewContainer');
const previewImage = document.getElementById('previewImage');
const fileInfo = document.getElementById('fileInfo');
const form = document.getElementById('matchingForm');
const submitBtn = document.getElementById('submitBtn');
const loading = document.getElementById('loading');
const resultSection = document.getElementById('resultSection');
const resultContent = document.getElementById('resultContent');
const apiUrlInput = document.getElementById('apiUrl');
const statusDot = document.getElementById('statusDot');
const statusText = document.getElementById('statusText');
// 页面加载时检测API状态
window.onload = function() {
checkApiStatus();
};
// 检测API状态
async function checkApiStatus() {
const apiUrl = apiUrlInput.value.trim();
statusText.textContent = '检测中...';
statusDot.className = 'status-indicator status-offline';
try {
const response = await fetch(`${apiUrl}/api/health`, {
method: 'GET',
timeout: 5000
});
if (response.ok) {
statusText.textContent = '服务在线';
statusDot.className = 'status-indicator status-online';
} else {
statusText.textContent = '服务离线';
statusDot.className = 'status-indicator status-offline';
}
} catch (error) {
statusText.textContent = '连接失败';
statusDot.className = 'status-indicator status-offline';
}
}
// 上传区域点击事件
uploadArea.addEventListener('click', () => {
fileInput.click();
});
// 拖拽上传功能
uploadArea.addEventListener('dragover', (e) => {
e.preventDefault();
uploadArea.classList.add('dragover');
});
uploadArea.addEventListener('dragleave', () => {
uploadArea.classList.remove('dragover');
});
uploadArea.addEventListener('drop', (e) => {
e.preventDefault();
uploadArea.classList.remove('dragover');
const files = e.dataTransfer.files;
if (files.length > 0) {
handleFileSelect(files[0]);
}
});
// 文件选择事件
fileInput.addEventListener('change', (e) => {
if (e.target.files.length > 0) {
handleFileSelect(e.target.files[0]);
}
});
// 处理文件选择
function handleFileSelect(file) {
if (!file.type.startsWith('image/')) {
alert('请选择图片文件!');
return;
}
// 显示预览
const reader = new FileReader();
reader.onload = (e) => {
previewImage.src = e.target.result;
previewContainer.style.display = 'block';
// 显示文件信息
fileInfo.innerHTML = `
<strong>文件名:</strong>${file.name}<br>
<strong>文件大小:</strong>${(file.size / 1024 / 1024).toFixed(2)} MB<br>
<strong>文件类型:</strong>${file.type}
`;
};
reader.readAsDataURL(file);
// 将文件对象保存到input中
const dt = new DataTransfer();
dt.items.add(file);
fileInput.files = dt.files;
}
// 上传图片
async function uploadImage(file) {
const formData = new FormData();
formData.append('image', file);
try {
const response = await fetch('/upload', {
method: 'POST',
body: formData
});
if (!response.ok) {
throw new Error(`上传失败: ${response.status}`);
}
const result = await response.json();
return result.url;
} catch (error) {
console.error('上传错误:', error);
throw error;
}
}
// 表单提交事件
form.addEventListener('submit', async (e) => {
e.preventDefault();
const file = fileInput.files[0];
const captionText = document.getElementById('captionText').value.trim();
const contextInfo = document.getElementById('contextInfo').value.trim();
const apiUrl = apiUrlInput.value.trim();
if (!file) {
alert('请选择图片文件!');
return;
}
if (!captionText) {
alert('请输入图片标题或描述文本!');
return;
}
if (!apiUrl) {
alert('请输入API服务地址!');
return;
}
// 显示加载状态
submitBtn.disabled = true;
loading.style.display = 'inline-block';
resultSection.style.display = 'none';
try {
// 上传图片获取URL
console.log('开始上传图片...');
const imageUrl = await uploadImage(file);
console.log('图片上传成功:', imageUrl);
// 调用图文匹配API
console.log('开始调用图文匹配API...');
const response = await fetch(`${apiUrl}/api/image_text_matching`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
illustration_url: imageUrl,
caption_text: captionText,
context_info: contextInfo
})
});
const result = await response.json();
console.log('API响应结果:', result);
displayResult(result);
} catch (error) {
console.error('处理失败:', error);
displayError(`处理失败: ${error.message}`);
} finally {
// 隐藏加载状态
submitBtn.disabled = false;
loading.style.display = 'none';
}
});
// 显示结果
function displayResult(result) {
resultSection.style.display = 'block';
if (result.code === 200) {
if (result.result && result.result.length > 0) {
// 有匹配问题
resultContent.className = 'result-content result-error';
let html = '<h4>❌ 检测到图文不匹配问题:</h4>';
result.result.forEach((item, index) => {
html += `
<div class="result-item">
<div class="error-type">错误类型: ${item.error_type}</div>
<div class="error-text"><strong>文本内容:</strong> "${item.context_text}"</div>
<div class="error-reason"><strong>不匹配原因:</strong> ${item.error_reson}</div>
</div>
`;
});
resultContent.innerHTML = html;
} else {
// 匹配成功
resultContent.className = 'result-content result-success';
resultContent.innerHTML = `
<h4>✅ 图文匹配检测通过</h4>
<p>图片与文本内容匹配度良好,未发现明显的不匹配问题。</p>
<div class="result-item">
<strong>检测模型:</strong> ${result.model || 'qwen'}<br>
<strong>检测时间:</strong> ${new Date().toLocaleString()}
</div>
`;
}
} else {
// API错误
resultContent.className = 'result-content result-error';
resultContent.innerHTML = `
<h4>❌ 检测失败</h4>
<p><strong>错误信息:</strong> ${result.message}</p>
<p><strong>错误代码:</strong> ${result.code}</p>
`;
}
}
// 显示错误信息
function displayError(message) {
resultSection.style.display = 'block';
resultContent.className = 'result-content result-error';
resultContent.innerHTML = `
<h4>❌ 处理失败</h4>
<p>${message}</p>
`;
}
// 清除结果
function clearResults() {
resultSection.style.display = 'none';
resultContent.innerHTML = '';
}
</script>
</body>
</html>'''
with open("web/index.html", "w", encoding="utf-8") as f:
f.write(html_content)
print("✓ 创建前端文件: web/index.html")
def start_server(port=8080):
"""启动Web服务器"""
print("正在初始化图文匹配检测服务...")
# 设置目录结构
setup_directories()
# 创建前端文件
create_index_html()
print("\n" + "="*50)
print("🚀 图文匹配检测服务启动中...")
print("="*50)
# 启动服务器
with socketserver.TCPServer(("", port), ImageTextMatchingHandler) as httpd:
server_address = ("", port)
print(f"✓ 服务器绑定地址: http://localhost:{port}")
print(f"✓ 静态文件目录: web/")
print(f"✓ 上传文件目录: uploads/")
print(f"✓ 外部API地址: http://localhost:29500")
try:
print(f"\n🌐 服务器启动成功!")
print(f"📱 请在浏览器中访问: http://localhost:{port}")
print(f"🔧 管理界面: http://localhost:{port}/api/health")
print("\n" + "="*50)
print("按 Ctrl+C 停止服务器")
print("="*50)
httpd.serve_forever()
except KeyboardInterrupt:
print("\n\n⏹️ 正在停止服务器...")
httpd.shutdown()
print("✓ 服务器已停止")
if __name__ == "__main__":
import sys
# 解析命令行参数
port = 8080
if len(sys.argv) > 1:
try:
port = int(sys.argv[1])
except ValueError:
print("❌ 端口号必须是数字")
sys.exit(1)
# 启动服务器
start_server(port)
\ No newline at end of file
import requests
import os
import mimetypes
from typing import Dict, Optional, Union, Tuple
from urllib.parse import quote
class OBSUploader:
def __init__(self, base_url: str = "https://open.raysgo.com", auth_token: Optional[str] = None):
"""
Initialize the OBS uploader.
Args:
base_url: The base URL for the API
auth_token: The authorization token for API access
"""
self.base_url = base_url.rstrip('/')
self.auth_token = auth_token
self.headers = {
'Authorization': f'Bearer {auth_token}' if auth_token else None
}
# Initialize mimetypes
mimetypes.init()
def _get_content_type(self, file_path: Union[str, bytes]) -> Tuple[str, bytes]:
"""
Get content type and file content from file path or bytes.
Args:
file_path: Path to the file or file content as bytes
Returns:
Tuple of (content_type, file_content)
"""
if isinstance(file_path, str):
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
content_type, _ = mimetypes.guess_type(file_path)
with open(file_path, 'rb') as f:
file_content = f.read()
else:
file_content = file_path
# For bytes input, try to detect type from first few bytes
content_type = 'application/octet-stream' # Default content type
return content_type or 'application/octet-stream', file_content
def get_upload_url(self, biz_code: str, object_name: str, content_type: str) -> Dict:
"""
Get a temporary upload URL for the specified object.
Args:
biz_code: Business code for the upload
object_name: Name/path of the object to upload
content_type: MIME type of the file
Returns:
Dict containing the upload URL and related information
"""
endpoint = f"{self.base_url}/aimodel/v1.0/obs/getCreatePostSignature"
params = {
'bizCode': biz_code,
'objectName': object_name,
'mimeType': content_type
}
response = requests.get(endpoint, params=params, headers=self.headers)
response.raise_for_status()
return response.json()
def upload_file(self, file_path: Union[str, bytes], biz_code: str, object_name: str) -> Dict:
"""
Upload a file using temporary credentials.
Args:
file_path: Path to the file to upload or file content as bytes
biz_code: Business code for the upload
object_name: Name/path of the object to upload
Returns:
Dict containing the upload result and file URL
"""
# Get content type and file content
content_type, file_content = self._get_content_type(file_path)
# Get temporary upload URL with content type
upload_info = self.get_upload_url(biz_code, object_name, content_type)
if upload_info['errCode'] != 0:
raise Exception(f"Failed to get upload URL: {upload_info['message']}")
upload_url = upload_info['data']['temporarySignatureUrl']
# Upload the file with the correct content type
headers = {
'Content-Type': content_type,
'Content-Length': str(len(file_content))
}
response = requests.put(upload_url, data=file_content, headers=headers)
response.raise_for_status()
return {
'success': True,
'file_url': upload_info['data']['domain'] + '/' + object_name,
'object_url_map': upload_info['data']['objectUrlMap']
}
# Example usage:
if __name__ == "__main__":
# Initialize uploader
uploader = OBSUploader(auth_token="dcg-4c1e3a7f4fcd415e8c93151ff539d20a")
# Upload a file
try:
result = uploader.upload_file(
file_path="/data/wangtengbo/formula_node4_测试/logs/logs/2025-03-02/images/0d307e97071846a1b144e7dfb4d44241_20250302073213192/formula_1.png",
biz_code="formula",
object_name="image/test.jpg"
)
print(f"File uploaded successfully! URL: {result['file_url']}")
except Exception as e:
print(f"Upload failed: {str(e)}")
\ No newline at end of file
import asyncio
import uvicorn
import threading
import json
import time
import sys
import os
import json
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger
from utils.common import get_millisecond_time,save_logs_to_file
import requests
import re
from pydantic import BaseModel
from config.config import DATA_LOGS_DATA,CHECK_IMAGE_CAPTION,CHECK_IMAGE_CAPTION_VLM,CHECK_IMAGE_CONTEXT,CHECK_IMAGE_CONTEXT_VLM
from config.config import VLM_Match_User_Prompt,VLM_Match_Context_User_Prompt
from datetime import datetime
from tasks.qwen_vl_infer import qwen_vl_infer
# 移除默认的日志处理器
logger.remove()
# 添加新的日志处理器,按天生成新日志文件
logger.add(
sink="/nfs/liuxin/work/Image_TextTitle_Matching/Image_Text_Matching_Server_Pro/logs/image_text_match_{time:YYYY-MM-DD}.log", # 日志文件按日期命名
format="{time:YYYY-MM-DD HH:mm:ss} {process} {level} {module}:{function}:{line}: {message}",
rotation="1 day", # 每天创建一个新文件
encoding="utf-8", # 日志文件编码
enqueue=True, # 异步写入
retention="90 days" # 保留最近 90 天的日志文件
)
app = FastAPI()
class Item(BaseModel):
illustration_url:str
caption_text:str = ""
context_info:str = ""
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
def read_root():
logger.info("Root endpoint was accessed")
return {"Hello": "World"}
@app.get('/health')
def health():
return "OK"
## 不调用版面检测接口
@app.post("/v1/image_text_matching")
def formula_parse_by_tex(item:Item):
logger.info('in image_text_matching')
#logger.info(f'item.illustration_url={item.illustration_url}\nitem.caption_text={item.caption_text}\ncontext_info={item.context_info}')
try:
#内容存储
content = {"result":[]}
#数据飞轮
logs_info={'url':'',
'image_path':'',
'caption_text':'',
'context_info':'',
'match_context_info':'',
'caption_check':'',
'context_check':''
}
if item.illustration_url is not None and len(item.illustration_url.strip())<=0:
logger.error("requests url is None: {0}")
content['code'] = 400
content['message'] = 'the url is empty!'
# content['base64_images'] = []
# content['result'] = {"error_formula":[],"error_reason":[],"corrected_formula":[]}
return JSONResponse(content=content)
if len(item.caption_text)<=0 and len(item.context_info)<=0:
logger.error("caption and context all None: {0}")
content['code'] = 400
content['message'] = 'caption and context all Empty:!'
# content['base64_images'] = []
# content['result'] = {"error_formula":[],"error_reason":[],"corrected_formula":[]}
return JSONResponse(content=content)
url = item.illustration_url.strip()
logs_info['url']=url
# 保存文件 #这里需要判断是否这个url连接的后缀是否包含特定元素
task_id = get_millisecond_time()
suffix=None
apart_file_url= url.split('/')[-1]
if '.' in apart_file_url:
file_name,suffix = url.split('/')[-1].split('.')
else:
file_name=apart_file_url
if suffix==None or suffix not in ['jpg','png','jpeg']:
suffix='jpg'
#创建日志文件夹
current_date = datetime.now().strftime("%Y-%m-%d")
daily_log_dir = os.path.join(DATA_LOGS_DATA, current_date)
# 创建当天的日志目录
os.makedirs(daily_log_dir, exist_ok=True)
#基于每日文件夹创建子文件存储,image,json,sub_img
json_save_dir=os.path.join(daily_log_dir,'jsons')
os.makedirs(json_save_dir, exist_ok=True)
image_save_dir=os.path.join(daily_log_dir,'images')
os.makedirs(image_save_dir, exist_ok=True)
#文件名存储
file_name=file_name+'_'+task_id
save_json_path=os.path.join(json_save_dir, file_name+'.json')
#每毫秒一个文件名,防止数据覆盖
save_file_path= os.path.join(image_save_dir, file_name+'.'+suffix)
logs_info['image_path']=save_file_path
#创建存储的文件,只会生效一次
start = time.time()
# 通过url下载文件
proxies = {"http": None, "https": None} # 绕过代理
response = requests.get(url, proxies=proxies, stream=True)
if response.status_code != 200:
logger.error("下载文件失败: {0}", url)
content['code'] = 400
content['message'] = f"下载文件失败: {url}"
content["result"] = []
return JSONResponse(content=content)
## 图像数据存储
with open(save_file_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=10240):
if chunk:
f.write(chunk)
logger.info(f'图像存储成功,path={save_file_path}')
end = time.time()
cost = end - start
logger.info("下载文件:\t{0}\t时间为:{1}秒", url, cost)
## 日志飞轮
logs_info['url']=item.illustration_url
logs_info['caption_text']=item.caption_text
logs_info['context_info']=item.context_info
logs_info['image_path']=save_file_path
### 一、图像和标题匹配
if CHECK_IMAGE_CAPTION and len(item.caption_text)>3:
error_dict={'context_text':'','error_reson':'',"error_type":'caption'}
#标题是图像的精炼表达。
if CHECK_IMAGE_CAPTION_VLM:
#提示词直接匹配? 图像理解再匹配?
qwen_match_response=qwen_vl_infer(item.illustration_url,'你是一个图文匹配判断专家。',VLM_Match_User_Prompt.replace("{{user_text}}",item.caption_text))
#logger.info(f'qwen_match_response={qwen_match_response}')
#日志存储
logs_info['caption_check']=qwen_match_response
if '匹配结果:' in qwen_match_response and '原因:' in qwen_match_response:
response_info=qwen_match_response.split('匹配结果:')[1]
is_match = response_info.split('原因:')[0].strip().replace('{','').replace('}','')
error_reson=response_info.split('原因:')[1].strip().replace('{','').replace('}','')
if is_match=='不匹配':
error_dict['context_text']=item.caption_text
error_dict['error_reson']=error_reson
content['result'].append(error_dict)
logger.info(f'in caption check=qwen_match_response={qwen_match_response}\n\timage_url={item.illustration_url}\n\t title_text={item.caption_text}\n\t context_info={item.context_info}')
else:
#私有模型方案,e.g.VIT BLIP.......
pass
# 当图像标题和正文内容相同
if item.context_info==item.caption_text:
content['code'] = 200
content['message'] = 'OK'
content['model'] = 'qwen'
return JSONResponse(content=content)
### 二、图像和正文匹配
#(1)提取图序和图标题内容; (2)拿到局部上下文; (3)图序检索和模糊匹配;(4)检测判断;
if CHECK_IMAGE_CONTEXT and len(item.context_info)>5 and len(item.caption_text)>3:
error_dict={'context_text':'','error_reson':'',"error_type":'context'}
#(1)提取图序
image_id=''
if len(item.caption_text)>3:
#提取图序
# 匹配图序:图 K、图K-1、图K-1-2 等,去掉空格
image_id_match = re.findall(r"图\s*\d+(?:-\d+)*", item['caption_text'])
# 去掉空格
if image_id_match:
image_id = image_id_match[0].replace(" ", "") # 输出: 图5 或 图5-7
else:
image_id = ''
logger.info(f'{item.caption_text} image_order is {image_id}')
paragraph_info=''
if paragraph_info=='':
#(2)拿到局部上下文,进行搜索(精准匹配re)
if len(image_id)>0 or (len(item.caption_text)>3 and item.caption_text in item.context_info and len(item.context_info)>5) :
if image_id in item.context_info:
# 只匹配从 image_id 开始到第一个 '\n' 之前的所有字符
pattern_context = rf"{re.escape(image_id)}[^\n]*"
search_context = re.search(pattern_context, item.context_info)
if search_context:
paragraph_info+= search_context.group(0)+'\n'
logger.info(f'{item.context_info} image_order context is {paragraph_info}')
#做图像内容局部匹配
#使用图像标题 在 正文内检索局部相关内容
pattern = rf'[^。]*{re.escape(item.caption_text)}[^。]*。'
match = re.search(pattern, item.context_info)
if match:
paragraph_info+=match.group(0)+'\n'
if len(paragraph_info)>5:
logs_info['match_context_info']=paragraph_info
#标题是图像的精炼表达。
if CHECK_IMAGE_CONTEXT_VLM:
#提示词直接匹配? 图像理解再匹配?
qwen_match_response=qwen_vl_infer(item.illustration_url,'你是一个图文匹配判断专家。',VLM_Match_Context_User_Prompt.replace("{{user_text}}",paragraph_info).replace("{{caption}}",item.caption_text))
#日志存储
logs_info['context_check']=qwen_match_response
#logger.info(f'qwen_match_response={qwen_match_response}')
if '匹配结果:' in qwen_match_response and '原因:' in qwen_match_response:
response_info=qwen_match_response.split('匹配结果:')[1]
is_match = response_info.split('原因:')[0].strip().replace('{','').replace('}','')
error_reson= response_info.split('原因:')[1].strip().replace('{','').replace('}','')
# logger.info(f'is_match={is_match}',error_reson={error_reson})
if is_match=='不匹配':
#error_dict['full_context_text']
error_dict['context_text']=item.context_info
error_dict['error_reson']=error_reson
content['result'].append(error_dict)
logger.info(f'in_context_check,have caption,=qwen_match_response={qwen_match_response}\n\timage_url={item.illustration_url}\n\t title_text={item.caption_text}\n\t context_info={item.context_info}')
# else:
logger.info(f'')
else:
error_dict={'context_text':'','error_reson':'',"error_type":'context'}
#没有caption
#标题是图像的精炼表达。
if CHECK_IMAGE_CONTEXT_VLM:
logger.info('context_check')
#提示词直接匹配? 图像理解再匹配?
qwen_match_response=qwen_vl_infer(item.illustration_url,'你是一个图文匹配判断专家。',VLM_Match_Context_User_Prompt.replace("{{user_text}}",item.context_info).replace("{{caption}}",'用户没有提供图像标题,仅需关注正文。'))
#日志存储
logger.info(f'qwen_match_response={qwen_match_response}')
logs_info['context_check']=qwen_match_response
#logger.info(f'qwen_match_response={qwen_match_response}')
if '匹配结果:' in qwen_match_response and '原因:' in qwen_match_response:
response_info=qwen_match_response.split('匹配结果:')[1]
is_match = response_info.split('原因:')[0].strip().replace('{','').replace('}','')
error_reson= response_info.split('原因:')[1].strip().replace('{','').replace('}','')
if is_match =='不匹配':
#error_dict['full_context_text']
error_dict['context_text']=item.context_info
error_dict['error_reson']=error_reson
content['result'].append(error_dict)
logger.info(f'in_context_check,no_caption,=qwen_match_response={qwen_match_response}\n\timage_url={item.illustration_url}\n\t title_text={item.caption_text}\n\t context_info={item.context_info}')
#pass
elif CHECK_IMAGE_CONTEXT:
error_dict={'context_text':'','error_reson':'',"error_type":'context'}
#没有caption
#标题是图像的精炼表达。
if CHECK_IMAGE_CONTEXT_VLM:
logger.info('context_check')
#提示词直接匹配? 图像理解再匹配?
qwen_match_response=qwen_vl_infer(item.illustration_url,'你是一个图文匹配判断专家。',VLM_Match_Context_User_Prompt.replace("{{user_text}}",item.context_info).replace("{{caption}}",'用户没有提供图像标题,仅需关注正文。'))
#日志存储
logger.info(f'qwen_match_response={qwen_match_response}')
logs_info['context_check']=qwen_match_response
#logger.info(f'qwen_match_response={qwen_match_response}')
if '匹配结果:' in qwen_match_response and '原因:' in qwen_match_response:
response_info=qwen_match_response.split('匹配结果:')[1]
is_match = response_info.split('原因:')[0].strip().replace('{','').replace('}','')
error_reson= response_info.split('原因:')[1].strip().replace('{','').replace('}','')
if is_match =='不匹配':
#error_dict['full_context_text']
error_dict['context_text']=item.context_info
error_dict['error_reson']=error_reson
content['result'].append(error_dict)
logger.info(f'in_context_check,no_caption,=qwen_match_response={qwen_match_response}\n\timage_url={item.illustration_url}\n\t title_text={item.caption_text}\n\t context_info={item.context_info}')
#pass
## 存储日志飞轮
save_logs_to_file(logs_info,file_name=save_json_path)
content['code'] = 200
content['message'] = 'OK'
content['model'] = 'qwen'
return JSONResponse(content=content)
except Exception as e:
content['code'] = 500
content['message'] = 'Server Exception'
content['result'] = []
return JSONResponse(content=content)
## 不调用版面检测接口
@app.post("/v1/image_text_matching_local_file")
def formula_parse_by_tex(item:Item):
logger.info('in image_text_matching_local_file')
#logger.info(f'item.illustration_url={item.illustration_url}\nitem.caption_text={item.caption_text}\ncontext_info={item.context_info}')
try:
#内容存储
content = {"result":[]}
#数据飞轮
logs_info={'url':'',
'image_path':'',
'sub_image_path':[],
'markdown_info':'',
'markdown_filter':'',
'formula_check':'',
'formula_check_postProcess':'',
'image_info':'',
'return_results':''
}
if item.illustration_url is not None and len(item.illustration_url.strip())<=0:
logger.error("requests url is None: {0}")
content['code'] = 400
content['message'] = 'the url is empty!'
# content['base64_images'] = []
# content['result'] = {"error_formula":[],"error_reason":[],"corrected_formula":[]}
return JSONResponse(content=content)
if len(item.caption_text)<=0 and len(item.context_info)<=0:
logger.error("caption and context all None: {0}")
content['code'] = 400
content['message'] = 'caption and context all Empty:!'
# content['base64_images'] = []
# content['result'] = {"error_formula":[],"error_reason":[],"corrected_formula":[]}
return JSONResponse(content=content)
# url = item.illustration_url.strip()
# logs_info['url']=url
# # 保存文件 #这里需要判断是否这个url连接的后缀是否包含特定元素
# task_id = get_millisecond_time()
# suffix=None
# apart_file_url= url.split('/')[-1]
# # if '.' in apart_file_url:
# # file_name,suffix = url.split('/')[-1].split('.')
# # else:
# # file_name=apart_file_url
# # if suffix==None or suffix not in ['jpg','png','jpeg']:
# # suffix='jpg'
# # #创建日志文件夹
# # current_date = datetime.now().strftime("%Y-%m-%d")
# # daily_log_dir = os.path.join(DATA_LOGS_DATA, current_date)
# # # 创建当天的日志目录
# # os.makedirs(daily_log_dir, exist_ok=True)
# # #基于每日文件夹创建子文件存储,image,json,sub_img
# # json_save_dir=os.path.join(daily_log_dir,'jsons')
# # os.makedirs(json_save_dir, exist_ok=True)
# # image_save_dir=os.path.join(daily_log_dir,'images')
# # os.makedirs(image_save_dir, exist_ok=True)
# # #文件名存储
# # file_name=file_name+'_'+task_id
# # save_json_path=os.path.join(json_save_dir, file_name+'.json')
# # #每毫秒一个文件名,防止数据覆盖
# # save_file_path= os.path.join(image_save_dir, file_name+'.'+suffix)
# # logs_info['image_path']=save_file_path
# # #创建存储的文件,只会生效一次
# # start = time.time()
# # # 通过url下载文件
# # proxies = {"http": None, "https": None} # 绕过代理
# # response = requests.get(url, proxies=proxies, stream=True)
# # if response.status_code != 200:
# # logger.error("下载文件失败: {0}", url)
# # content['code'] = 400
# # content['message'] = f"下载文件失败: {url}"
# # content["result"] = []
# # return JSONResponse(content=content)
# # ## 图像数据存储
# # with open(save_file_path, 'wb') as f:
# # for chunk in response.iter_content(chunk_size=10240):
# # if chunk:
# # f.write(chunk)
# # logger.info(f'图像存储成功,path={save_file_path}')
end = time.time()
# cost = end - start
# logger.info("下载文件:\t{0}\t时间为:{1}秒", url, cost)
### 图像和标题匹配
if CHECK_IMAGE_CAPTION:
error_dict={'context_text':'','error_reson':'',"error_type":'caption'}
#标题是图像的精炼表达。
if CHECK_IMAGE_CAPTION_VLM:
#提示词直接匹配? 图像理解再匹配?
qwen_match_response=qwen_vl_infer(item.illustration_url,'你是一个图文匹配判断专家。',VLM_Match_User_Prompt.replace("{{user_text}}",item.caption_text))
logger.info(f'qwen_match_response={qwen_match_response}')
if '匹配结果:' in qwen_match_response and '原因:' in qwen_match_response:
response_info=qwen_match_response.split('匹配结果:')[1]
is_match = response_info.split('原因:')[0].strip()
error_reson=response_info.split('原因:')[1].strip()
if is_match=='不匹配':
error_dict['context_text']=item.caption_text
error_dict['error_reson']=error_reson
content['result'].append(error_dict)
else:
#私有模型方案,e.g.VIT BLIP.......
pass
### 图像和正文匹配
content['code'] = 200
content['message'] = 'OK'
content['model'] = 'qwen'
return JSONResponse(content=content)
except Exception as e:
content['code'] = 500
content['message'] = 'Server Exception'
content['result'] = []
return JSONResponse(content=content)
if __name__=="__main__":
logger.info("image text matching check start successful!")
uvicorn.run(app=f'main:app', host="0.0.0.0", port=29500, reload=False, workers=10) # 提供后端的接口 http://61.170.32.13:29500/v1/image_text_matching
# uvicorn.run(app=f'main:app', host="0.0.0.0", port=29501, reload=False, workers=3) # 个人测试测试
# cd /nfs/liuxin/work/Image_TextTitle_Matching
# conda activate text_check
# nohup python -u main.py > main_image_title.log 2>&1 &
# tail -f main_image_title.log
# ss -ntlp | grep 29500
import os
# 通过 pip install volcengine-python-sdk[ark] 安装方舟SDK
from volcenginesdkarkruntime import Ark
# 替换 <MODEL> 为模型的Model ID
model="doubao-1.5-vision-pro"
# 初始化Ark客户端,从环境变量中读取您的API Key
client = Ark(
api_key="fc61954e-f585-4aac-b88a-7baf56e05d9e",
)
# 创建一个对话请求
response = client.chat.completions.create(
# 指定您部署了视觉理解大模型的推理接入点ID
model = model,
messages = [
{
# 指定消息的角色为用户
"role": "user",
"content": [
# 图片信息,希望模型理解的图片
{"type": "image_url", "image_url": {"url": "https://ark-project.tos-cn-beijing.volces.com/doc_image/ark_demo_img_1.png"},},
# 文本消息,希望模型根据图片信息回答的问题
{"type": "text", "text": "支持输入是图片的模型系列是哪个?"},
],
}
],
)
print(response)
print(response.choices[0].message.content)
\ No newline at end of file
import re
def remove_blank_lines(text):
# 使用正则表达式去除空行
result = re.sub(r'\n\s*\n', '\n', text)
return result
if __name__ == "__main__":
sub_text_infos="""# 第一章 机械运动
表 1-1 地球上不同纬度的重力加速度
| 地点 | 赤道 | 广州 | 上海 | 北京 | 北极 |
| --- | --- | --- | --- | --- | --- |
| 纬度 | 0° | 23°06′ | 30°12′ | 39°56′ | 90° |
| $g$ (m/s²) | 9.780 | 9.788 | 9.794 | 9.801 | 9.832 |
"""
result=remove_blank_lines(sub_text_infos)
print(result)
import os
import sys
import re
from loguru import logger
import base64
#from config.config import PROMPT
from openai import OpenAI
import time
import requests
import json
from datetime import datetime
import cv2
import ast
from config.config import LAYOUT_CHECK_URL,FORMULA_DETECTION_URL,LAYOUT_CHECK_URL_19
from config.config import Official_API_KEY
from config.config import Official_OPENAI_URL,ENV
from config.config import QWEN_URL,QWEN_API_KEY
from utils.common import get_millisecond_time
from tasks.mysql_utils import DBUtils
#class Formula_Checker(metaclass=Singleton):
class Formula_Checker():
def __init__(self):
#self.GPT = GPTModel()
#self.PDFUTILS = PDFUtils()
#self.PDFTOOLS = PDFTools()
#self.chapter_prompt_task_id = 8000000
#self.chapter_prompt_task_id_again = 8000001
pass
def perform_re_check(self,ocr_result):
"""
检查OCR结果中是否包含公式。
参数:
ocr_result (dict): 包含OCR识别结果的字典,格式为:
{
'errorCode': int,
'msg': str,
'data': list[str]
}
返回:
bool: 如果包含公式则返回True,否则返回False。
"""
# 正则表达式用于匹配数学公式中的常见符号和结构,避免简单变量和表达式
# formula_pattern = re.compile(
# r"([A-Za-z]+\s*=\s*[A-Za-z0-9+\-*/^()]+(?:\s*[+\-*/^]+\s*[A-Za-filter_correct_formulas_elem-9+\-*/^()]+)+)|" # 复杂公式,包含多个运算符
# r"(\b√[A-Za-z0-9]+\b)|" # 根号
# r"(\bΔ\b)|" # Δ
# r"(\([A-Za-z0-9+\-*/^()]+\)\s*[+\-*/^]\s*\([A-Za-z0-9+\-*/^()]+\))|"
# r"([A-Za-z]*\d*[(x)(y)(z)(a)(b)(c)]*\s*[+\-*/^]+\s*\(?[A-Za-z0-9+\-*/^()]+\)?)" # 公式中的符号运算
# )
formula_pattern = re.compile(
r"([A-Za-z]+\s*=\s*[A-Za-z0-9+\-*/^()]+)|" # 一般公式,包含多个运算符或字母
r"(\b√[A-Za-z0-9]+\b)|" # 根号
r"(\bΔ\b)|" # Δ
r"(\([A-Za-z0-9+\-*/^()]+\)\s*[+\-*/^]\s*\([A-Za-z0-9+\-*/^()]+\))|" # 复杂括号表达式
r"([A-Za-z]*\d*[(x)(y)(z)(a)(b)(c)]*\s*[+\-*/^]+\s*\(?[A-Za-z0-9+\-*/^()]+\)?)" # 公式中的符号运算
)
if ocr_result['errorCode'] != 0:
logger.error(f"OCR识别失败,错误信息: {ocr_result['msg']}")
return False
# 将OCR结果中的文本内容合并成一个字符串
combined_text = ' '.join(ocr_result['data'])
# 检查合并后的文本中是否包含公式
if formula_pattern.search(combined_text):
logger.info("发现公式")
return True
logger.info("OCR结果中不包含公式。")
return False
def perform_ocr(self,file_path):
"""
进行OCR识别的函数。
参数:
file_path (str): 图片文件的路径
返回:
dict: OCR识别结果的JSON响应
"""
# url = "https://dcg-ai-red-list.5rs.me/v1/dcg_ocr"
url = OCR_URL
params = {
"userid": "yxI_110",
"client_id": "dcg-red-list"
}
headers = {"Authorization": "Bearer dcg-MTQ2MDRkYWRmNzRjMDg0ZjZmNTc3YTliMWM0YzYwYmVlZDE="}
timeout_duration = 5
proxies = {"http": None, "https": None}
try:
logger.info("开始进行OCR请求")
with open(file_path, "rb") as file:
res = requests.post(url, files={"file": file}, data=params, headers=headers,timeout=timeout_duration, proxies=proxies)
if res.status_code == 200:
logger.info("OCR请求成功")
return res.json()
else:
logger.error(f"OCR请求失败,状态码: {res.status_code}")
return {"error": f"Request failed with status code {res.status_code}"}
except Exception as e:
logger.exception("OCR请求过程中出现异常")
return {"error": str(e)}
def perform_ocr_all(self,file_path):
"""
进行OCR识别的函数。
参数:
file_path (str): 图片文件的路径
返回:
dict: OCR识别结果的JSON响应
"""
# url = "https://dcg-ai-red-list.5rs.me/v1/dcg_ocr"
url = OCR_URL
params = {
"userid": "yxI_110",
"client_id": "dcg-red-list",
"show_details":"True"
}
headers = {"Authorization": "Bearer dcg-MTQ2MDRkYWRmNzRjMDg0ZjZmNTc3YTliMWM0YzYwYmVlZDE="}
timeout_duration = 3
proxies = {"http": None, "https": None}
#response = requests.post(url, data=json_data, headers=headers, timeout=600, proxies=proxies)
try:
logger.info("开始进行OCR请求")
with open(file_path, "rb") as file:
res = requests.post(url, files={"file": file}, data=params, headers=headers,timeout=timeout_duration, proxies=proxies)
if res.status_code == 200:
logger.info("OCR请求成功")
return res.json()
else:
logger.error(f"OCR请求失败,状态码: {res.status_code}")
return {"error": f"Request failed with status code {res.status_code}"}
except Exception as e:
logger.exception("OCR请求过程中出现异常")
return {"error": str(e)}
def filter_correct_formulas(self,data):
filtered_result = []
logger.info('in filter_correct_formulas ,data={}'.format(data))
for entry in data:
error_formulas = entry['error_formula']
corrected_formulas = entry['corrected_formula']
error_reasons = entry['error_reason']
# 创建新的条目字典,初始化为空列表
new_entry = {
'error_formula': [],
'error_reason': [],
'corrected_formula': []
}
for i in range(len(error_formulas)):
if error_formulas[i] != corrected_formulas[i]:
new_entry['error_formula'].append(error_formulas[i])
new_entry['error_reason'].append(error_reasons[i])
new_entry['corrected_formula'].append(corrected_formulas[i])
# 只有在新条目不为空时才添加到过滤结果中
if new_entry['error_formula']:
filtered_result.append(new_entry)
return filtered_result
# def filter_correct_formulas_elem(self,data):
# filtered_result = {
# 'error_formula': [],
# 'error_reason': [],
# 'corrected_formula': []
# }
# error_formulas = data['error_formula']
# corrected_formulas = data['corrected_formula']
# error_reasons = data['error_reason']
# filter_list=['与标准方程的定义不符','公式未完整显示','公式顺序错误','计算错误','公式未完成','公式不完整','分母部分不正确','分子部分不准确','乘法的形式','坐标错误','公式缺少负号','速度可以为负值','总电阻的倒数等于各电阻倒数之和','分母应为']
# for i in range(len(error_formulas)):
# # if "=" not in error_formulas[i]:
# # continue
# # temp_error=error_formulas[i]
# # temp_corrected=corrected_formulas[i]
# if error_formulas[i] != corrected_formulas[i] and error_reasons[i] not in filter_list and '\frac{\Delta \Phi}{R}' not in error_formulas[i]:
# filtered_result['error_formula'].append(error_formulas[i])
# filtered_result['error_reason'].append(error_reasons[i])
# filtered_result['corrected_formula'].append(corrected_formulas[i])
# return filtered_result
def check_numbers_in_string(self,text):
"""
检查字符串中的数字,如果包含大于或等于5的数字返回False,否则返回True。
:param text: 输入的字符串
:return: 布尔值
"""
# 提取字符串中的所有数字,包括中间可能有空格的情况
numbers = re.findall(r'\d(?:\s*\d)*', text) # 匹配数字,允许中间有空格
#print("提取的原始数字:", numbers)
# 去除空格并转换为整数类型
int_numbers = [int(num.replace(' ', '')) for num in numbers]
#print("转换为整数后的数字:", int_numbers)
# 检查是否存在大于或等于5的数字
for num in int_numbers:
if num >= 5:
return False
return True
def filter_correct_formulas_elem(self, data):
filtered_result = {
'error_formula': [],
'error_reason': [],
'corrected_formula': []
}
error_formulas = data['error_formula']
corrected_formulas = data['corrected_formula']
error_reasons = data['error_reason']
filter_list = [
'无意义',
'不等号的使用不当',
'上下文中',
'公式中不应有系数2',
'不需要正负号',
'未定义',
'不规范的变量表示',
'公式不清晰',
'使用了小写字母b',
'不规范的表达式',
'导致公式不规范',
'不符合标准的向量表示法',
'拼写错误',
'符号错误',
'向量的表示',
'错误的向量表示',
'错误的符号和表达',
'向量符号书写错误',
'向量的表示应为',
'单位书写错误',
'数字书写错误',
'变量顺序错误',
'公式中使用了余弦而不是点积',
'平均速度与平均速率',
'单位描述不一致',
'无实际意义',
'取值不准确',
'等式不正确',
'计算结果',
'缺少计算结果',
'公式中缺少负号',
'公式中缺少乘号',
'公式与余弦定理不符',#余弦定理的推导式子
'三角形内角的正弦平方和不等于另一个角的正弦平方',
'等号使用错误',
'公式缺少一个比值',
'不完整的三角形符号',
'角度的正弦函数',
'根号符号不正确',
'符号表示错误',
'缺少单位向量',
'逗号',
'值错误',
'等号两边不相等',
'函数值错误',
'应用错误',
'包含中文逗号',
'样本空间定义不一致',
'多余的变量e',
'推导过程', #新数据
'公式等号前后不相等',
'公式与平行四边形的定义不符',
'公式中缺少乘法符号',
'缺少联合符号',
'矢量表示错误,缺少方向单位向量',
'方程组的解与原方程不匹配',
##过滤目前新的数据的问题
'等式两边不相等',
'乘积而不是相乘',
'公式中变量范围的表示方式不正确',
'公式描述不准确',
'不等式错误地变成了等式',
'缺少等号',
r"计算错误,30 \times 30 \times 20 不等于 1800",
r"x^2 = 18000,\sqrt{18000} 不等于 \sqrt{1800}",
r"x \geq 0 与 x = 30 \sqrt{2} 不一致",#离谱错误,初中数学
'不等式的解法错误',
'公式中条件部分应与公式分开',
'公式描述与法拉第电磁感应定律不符,缺少负号。',
'电阻分配和电流成反比',
'不等式的条件错误',
'不等式错误地转化为等式',
'公式中平方根的表达式应为绝对值',
'不等式条件下的等式错误',
'公式前后不一致',
'公式缺少描述',
'化简绝对值时符号错误',
'不符合前面的计算结果',
'化简绝对值时符号错误',
'公式中平方根的简化不正确', #初中数学
'展开错误','公式不完整','矢量加法和减法错误',
'公式不完整',
'公式中重复了等号',
'公式中的逻辑错误',
##未知文稿的错误
'位移符号错误',#应用物理的句号。被认为下标
'不够精确',
'角速度计算错误导致线速度计算错误', '使用了近似值3.14','单位换算错误',
'化简绝对值时符号错误', '公式中平方根的展开不正确','不正确的解集','平方根计算错误','绝对值函数展开错误',
'公式中未进行化简', '分解错误','化简错误','公式不完整','公式重复','错误的化简步骤','不符合平方根的正负值','平方根计算错误','不等式条件下的等式不成立','平方根的定义不正确','公式中平方根的表达式应为绝对值',
'展开错误','进行化简','化简绝对值时符号错误','因式分解错误','公式推导错误', '不是一个完全平方数的分解形式','不是一个正整数','公式重复','公式推导错误', '公式未完整显示', '公式顺序错误', '计算错误', '公式未完成',
'公式不完整', '分母部分不正确', '分子部分不准确', '乘法的形式', '坐标错误', '公式缺少负号',
'速度可以为负值', '总电阻的倒数等于各电阻倒数之和', '分母应为','灯','电流符号','重复'
]
filter_latex=[
'3x \\geq 0 \\text{且}-x \\geq 0',
"\\frac{\\sqrt{2Rh_{1}}}{\\sqrt{2Rh_{2}}}",
"r = \\sqrt{2Rh}",
"\\sqrt { \\frac { x - 5 } { 7 - x } } = \\frac { \\sqrt { x - 5 } } { \\sqrt { 7 - x } }",
"\\frac { \\sqrt { x + 1 } } { \\sqrt { x - 1 } } = \\sqrt { \\frac { x + 1 } { x - 1 } }",
'x = \\pm \\sqrt{1800}',
'\\frac { \\sqrt { 2 R h _ { 1 } } } { \\sqrt { 2 R h _ { 2 } } }',
'r = \\sqrt { 2 R h }',
'2-a≥0且a+1≠0',
'3x≥0且-x≥0',
'\\sqrt { 3 x } + \\sqrt { - x }',
'b + \\frac { 1 } { 2 } =0',
'a-2=0',
'R _ { 2 } = \\rho \\frac { 2 l } { S _ { 2 } }',
'x = \\pm \\sqrt{1800}.',
'\\sqrt{(-36) \\times 16 \\times (-9)}',
r'm g = q v B + q \\frac { E } { d }',
r'Q = \\overline { I } \\cdot \\Delta t = \\frac { \\overline { E } } { R } \\cdot \\Delta t = \\frac { \\Delta \\Phi } { R }',
r'P _ { 热 } = P R',
r"a = \frac { v ^ { 2 } } { r }",
r"\sqrt { \\frac { x - 5 } { 7 - x } } = \\frac { \\sqrt { x - 5 } } { \\sqrt { 7 - x } }",
r"P_{出max}=\frac{E^2}{4r}"
"P_{出max}=\\frac{E^2}{4r}",
r"\eta = \frac{U}{E} = \frac{R}{R+r}",
"\\eta = \\frac{U}{E} = \\frac{R}{R+r}",
r"P_{出max}=\\frac{E^2}{4r}",
r"\\sqrt{(a)^2}=a(a\\geq0)",
r"\sqrt{(a)^2}=a(a\geq0)",
'\\omega = 2\\pi n',r'\omega = 2\pi n',
'r"\eta = \frac{U}{E} = \frac{R}{R + r}"',
r"\cos \angle OF_1B_1 = e",
r"\omega = 2 \pi n",
r"\Delta \theta = \frac{\omega^2 - \omega_0^2}{2\beta}",
r"\sqrt{(y_1 + y_2)^2 - 4y_1 y_2}",
r"|AB| = \sqrt{1 + \frac{1}{k^2}} |y_1 - y_2|",
r"\left( \frac{p}{2}, 0 \right)", r"x = -\frac{p}{2}", r"\left( -\frac{p}{2}, 0 \right)", r"x = \frac{p}{2}",
##新一轮过滤
r"g = \frac{2h}{t^2}",
r"E=\frac{\Delta \Phi}{\Delta t}",
r"E = -\frac{\Delta \Phi}{\Delta t}",
r"Q=I \cdot \Delta t=\frac{E}{R} \cdot \Delta t=\frac{\Delta \Phi}{R}",
r"E=\frac{\Delta \Phi}{\Delta t}",
#新优化错误
r"P(A) = P(AB \\cup AB) = P(AB) + P(AB)", r"P(AB) = P(A) - P(A)P(B)", r"P(ABC) = P(A)P(B)P(C)",
r"P(A) = P(AB \cup AB) = P(AB) + P(AB)",
r"P(AB) = P(A)P(B)",
r"P(A) = P(AB \cup AB) = P(AB) + P(AB)",
r"|AB| = \\sqrt{1 + \\frac{1}{k^2}} |y_1 - y_2|",
r"c^2 = b^2 + a^2",
##新数据
'|B_1F_1| = a',
'\\sqrt{\\frac{a}{b}} = \\frac{\\sqrt{a}}{\\sqrt{b}} (a \\ge 0, b > 0)',
r"\sqrt{\frac{a}{b}} = \frac{\sqrt{a}}{\sqrt{b}} (a \ge 0, b > 0)",
r'\sqrt{a} \cdot \sqrt{b} = \sqrt{ab} (a \geq 0, b \geq 0)',
r'\sqrt{ab} = \sqrt{a} \cdot \sqrt{b} (a \geq 0, b \geq 0)',
'\\sqrt{a} \\cdot \\sqrt{b} = \\sqrt{ab} (a \\geq 0, b \\geq 0)',
'\\sqrt{ab} = \\sqrt{a} \\cdot \\sqrt{b} (a \\geq 0, b \\geq 0)',
r"y = \frac{p}{2}",
r"y = \\frac{p}{2}",
'k_{AB} = -\\frac{b^2 x_0}{a^2 y_0}'
r'\sqrt{(x_1 + x_2)^2 - 4x_1 x_2}',
'\\sqrt{(x_1 + x_2)^2 - 4x_1 x_2}',
r"$|B_1F_1\right| = a$",
r"|B_1F_1\right| = a",
r"\left|B_1F_1\right| = a",
r"\overrightarrow{CA} = \overrightarrow{OA} - \overrightarrow{OC}",
r"\sqrt{\frac{a}{b}} = \frac{\sqrt{a}}{\sqrt{b}} (a \geq 0, b > 0)",
r"30 \times 30 \times 20 = x^2 \times 10,得 x^2 = 1800,",
r"\therefore x = \pm \sqrt{1800}。",
r"\therefore x \geq 0,x = 30 \sqrt{2}。",#离谱三个初中数学
r"\sqrt{a^2} = a (a \geq 0)",
'\\sqrt{a-2} \\geq 0, \\sqrt{b+\\frac{1}{2}} \\geq 0', #初中数学
'\\sqrt{a-2} \\geq 0, \\sqrt{b+\\frac{1}{2}} \\geq 0',#初中数学
r"n = \frac{\omega}{2\pi}",#自动补全
r"\Delta \theta = \frac{\omega^2 - \omega_0^2}{2\beta}",
r"H_0 = v_0 \sin \theta t - \frac{1}{2}gt^2",
r"H_0 = v_0 \sin \theta \frac{v_0 \sin \theta}{g} - \frac{1}{2}g \left( \frac{v_0 \sin \theta}{g} \right)^2",
r'\frac{E^2}{(R-r)^2+4r}',
'\\frac{E^2}{(R-r)^2+4r}',
r'P=IU',
r"\frac{U_1}{U_2} = \frac{R_1}{R_2}",
r"\dfrac{(n_1+n_2)e}{t}",
r"\{M ||MF| = d\}",
r"\left|AB\right| = \sqrt{1 + \left( \frac{1}{k} \right)^2} \left|y_1 - y_2\right| = \sqrt{1 + \left( \frac{1}{k} \right)^2} \sqrt{(y_1 + y_2)^2 - 4y_1 y_2}",
r"E=n\frac{\Delta \Phi}{\Delta t}=L\frac{\Delta I}{\Delta t}",
r"\\frac{E^2R}{(R+r)^2}",
r"\\frac{E^2}{(R-r)^2+4r}",
r"\frac{E^2R}{(R+r)^2}",
r"\frac{E^2}{(R-r)^2+4r}",
r"\cos \angle OF_1B_1 = e",
r"e = \frac{c}{a}",
r"c^2=b^2+a^2",r"\frac{x^2}{a^2}+\frac{y^2}{b^2}=1",
r"\overline{v}——物体在\Delta t时间内的平均速度",
r"\sqrt{(x-2)^2} - \sqrt{(1-2x)^2} = (x-2) + (1-2x) = -x-1",
r"$$|x-2|+\sqrt{(x+3)^2}+\sqrt{x^2-10x+25}$$",
r"\frac{\sqrt{2Rh_1}}{\sqrt{2Rh_2}}",
r"$t=\sqrt{\frac{h}{5}}$",
r"$h = 5t^2$",
r"h=5t^2",
r"$|x-2|+\sqrt{(x+3)^2}+\sqrt{x^2-10x+25}$",
r"$$|x-2|+\sqrt{(x+3)^2}+\sqrt{x^2-10x+25}$$",
r"\frac{\sqrt{2Rh_1}}{\sqrt{2Rh_2}}",
r"\overrightarrow{OM_1}",
r'\Delta v_{\text{飞}} = 300 - 0 = 300(km/s) = 83(m/s)',
'\\Delta v_{\\text{飞}} = 300 - 0 = 300(km/s) = 83(m/s)',
r"\Delta v_{\text{飞}} = 300 - 0 = 300(km/s) = 83(m/s)",
r"|AB| = \sqrt{1 + k^2} |x_1 - x_2| = \sqrt{1 + k^2} \sqrt{(x_1 + x_2)^2 - 4x_1x_2}",
r"\left|AB\right| = \sqrt{1 + \left(\frac{1}{k}\right)^2} \left|y_1 - y_2\right| = \sqrt{1 + \left(\frac{1}{k}\right)^2} \sqrt{(y_1 + y_2)^2 - 4y_1y_2}",
r"h = 5t^2",r"\frac{\Delta \Phi}{R}",r"\omega = 2 \pi n",
r"\varphi = \alpha = 2\theta = \omega t",r"Q=I \cdot \Delta t=\frac{\overline{E}}{R} \cdot \Delta t=\frac{\Delta \Phi}{R}",
r"n \frac{\Delta \Phi}{\Delta t}",r"\eta=\frac{U}{E}=\frac{R}{R+r}"]
filter_correct_formula=[
r"B=\\frac{F}{I \\cdot L}",
r"F=Bqv \sin \theta",
r"\eta = \frac{U}{E} = \frac{R}{R+r}",
r"\\eta = \\frac{U}{E} = \\frac{R}{R+r}",
r"I = \\frac{Q}{t}",
r"|AB| = \\sqrt{1 + k^2} |y_1 - y_2|",
'|B_1F_1| = b',
r'E=-\frac{\Delta \Phi}{\Delta t}',
'E=-\\frac{\\Delta \\Phi}{\\Delta t}',
r"y^2 = -2px \ (p < 0)",
r"\left( \frac{p}{2}, 0 \right)",
r"x = -\frac{p}{2}",
r"\left( 0, -\frac{p}{2} \right)",
r"y = -\frac{p}{2}",
r"y^2 = -2px \ (p < 0)",
r"\left( \frac{p}{2}, 0 \right)",
r"x = -\frac{p}{2}",
r"x^2 = -2py \ (p < 0)",
r"\left( 0, \frac{p}{2} \right)",
r"y = -\frac{p}{2}",
r"y^2 = 2px \ (p > 0)",
r"\left( \frac{p}{2}, 0 \right)",
r"x = -\frac{p}{2}",
r"x^2 = 2py \ (p > 0)",
r"\left( 0, \frac{p}{2} \right)",
r"y = -\frac{p}{2}",
'k_{AB} = -\\frac{b^2}{a^2} \\cdot \\frac{x_0}{y_0}'
]
# for i in range(len(error_formulas)):
# if (
# error_formulas[i] != corrected_formulas[i] and
# all(item not in error_reasons[i] for item in filter_list) and
# '\frac{\Delta \Phi}{R}' not in error_formulas[i] and '=' in error_formulas[i]
# ):
# filtered_result['error_formula'].append(error_formulas[i])
# filtered_result['error_reason'].append(error_reasons[i])
# filtered_result['corrected_formula'].append(corrected_formulas[i])
#修正前
# for i in range(len(error_formulas)):
# if (
# error_formulas[i] != corrected_formulas[i] and
# all(item not in error_reasons[i] for item in filter_list) and
# '\frac{\Delta \Phi}{R}' not in error_formulas[i] and
# all(item not in error_formulas[i] for item in filter_latex) and
# all(item not in corrected_formulas[i] for item in filter_correct_formula)
# ):
# filtered_result['error_formula'].append(error_formulas[i])
# filtered_result['error_reason'].append(error_reasons[i])
# filtered_result['corrected_formula'].append(corrected_formulas[i])
#重新修正规则
for i in range(len(error_formulas)):
if (
self.check_numbers_in_string(corrected_formulas[i]) and
len(error_formulas[i])!=0 and
all(item not in error_reasons[i] for item in filter_list) and
'\frac{\Delta \Phi}{R}' not in error_formulas[i] and
all(item not in error_formulas[i] for item in filter_latex) and
all((item not in corrected_formulas[i]) for item in filter_correct_formula)
):
filtered_result['error_formula'].append(error_formulas[i])
filtered_result['error_reason'].append(error_reasons[i])
filtered_result['corrected_formula'].append(corrected_formulas[i])
return filtered_result
def filter_text_answer(self,text):
# 要过滤的关键词
keywords = ['证明','证明:','证明:','解','解:','解:','例题', '解析:','名师经验谈:','答案','解析','名师经验谈','答案:','解析:','名师经验谈:','答案:','各师经验谈:']
# 正则表达式匹配标题及其内容
pattern = re.compile(r'(#{1,6} .+?)(?=\n#{1,6} |\Z)', re.S | re.M)
# 查找所有段落
matches = pattern.findall(text)
# 过滤段落,仅对标题进行检查
filtered_paragraphs = []
last_end = 0
for match in matches:
start = text.find(match, last_end)
end = start + len(match)
# 提取标题内容中的文字部分
title = re.sub(r'^[#\d、. ]+', '', match.split('\n')[0]).strip()
if not any(title == keyword for keyword in keywords):
filtered_paragraphs.append(text[last_end:start])
filtered_paragraphs.append(match)
last_end = end
# 添加最后一部分内容
filtered_paragraphs.append(text[last_end:])
# 合并过滤后的段落
filtered_text = ''.join(filtered_paragraphs)
return filtered_text
def formula_checker(self,image_path,prompt):
if image_path is None or len(image_path)<=0:
raise EnvironmentError("image_path Errors!")
# Function to encode the image
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
# Getting the base64 string
base64_image = encode_image(image_path)
logger.info('api-key和endPoints 没有做封装,生产环境注意!')
api_key = Official_API_KEY
end_point = Official_OPENAI_URL
timeout = 30
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
message = [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}
]
#message = [system_prompt, user_prompt]
payload = {
"model": "gpt-4o",
"messages": message,
"max_tokens": 1000,
"temperature": 0.0
}
response = requests.post(url=end_point, headers=headers, json=payload, timeout=timeout)
response_json = response.json()
logger.info('response_json={}'.format(response_json))
if response.status_code == 200:
if 'choices' in response_json and len(response_json['choices']) > 0:
model_reply = response_json['choices'][0]['message']['content']
#logger.info('gpt return infos={}'.format(model_reply))
return model_reply,base64_image
else:
logger.info('model_reply is Empty!')
return ''
else:
print("请求服务器错误")
return ''
def write_token_logs(self,params):
start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
model_status = params['status']
url = params['url']
api_key=params['api-key']
input_token = params['input_token']
output_token = params['output_token']
model = params['model']
end_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
input_text=params['text']
DBUtils.insert_llm_log(env=ENV, source_type="TEXT", server_name="formula_corrector",
request=input_text[:50], model=model, url=url,
api_key=api_key, input_token=input_token,
output_token=output_token, start_time=start_time,
end_time=end_time, message=model_status)
return 1
#{'error': {'message': 'You exceeded your current requests list.', 'type': 'limit_requests', 'param': None, 'code': 'limit_requests'}, 'request_id': '823d3e60-9d84-99b0-9a48-3d07875ff214'}
def qwen_official_infer(self,system_prompt,user_prompt):
request_flag = 0
requests_list_model=['qwen-max-latest','qwen-max']
try:
client = OpenAI(
api_key=QWEN_API_KEY,
base_url=QWEN_URL,
)
start_time = time.time()
start_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start_time))
for model_name in requests_list_model:
completion = client.chat.completions.create(
model=model_name,
messages=[
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': user_prompt}
],
temperature=0.1,
seed=42,
top_p=0.8
)
res_json = completion.model_dump_json()
res_json = json.loads(res_json)
# 验证是否请求成功!
if 'error' not in res_json:
logger.info(f'qwen-max infer successfully, response: {res_json}, time: {time.time() - start_time}s')
request_flag=1
break
return_params = {
'status': 'Failed',
'model': 'qwen-max-latest',
'text': user_prompt[:3]
}
try:
if 'choices' in res_json and len(res_json['choices']) > 0:
content = res_json['choices'][0]['message']['content']
return_params['input_token'] = res_json['usage']['prompt_tokens']
return_params['output_token'] = res_json['usage']['completion_tokens']
return_params['status'] = 'OK'
# Logging the token usage
end_time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
DBUtils.insert_llm_log(
env='PRO',
source_type="TEXT",
server_name="qwen_official_infer",
request=user_prompt[:10],
model=return_params['model'],
url="formula_correction",
api_key="formula_correction",
input_token=return_params['input_token'],
output_token=return_params['output_token'],
start_time=start_time_str,
end_time=end_time_str,
message=return_params['status']
)
return content
else:
logger.error("No valid choices found in response.")
except Exception as e:
logger.error(f'qwen_official_infer error, error: {str(e)}, response: {res_json}')
except Exception as e:
logger.error(f'qwen_official_infer failed, error: {str(e)}')
return ''
def formula_checker_mini(self,image_path,prompt,text,bool_img=True):
return_params={}
if image_path is None or len(image_path)<=0:
raise EnvironmentError("image_path Errors!")
# Function to encode the image
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
# Getting the base64 string
logger.info('api-key和endPoints 没有做封装,生产环境注意!')
api_key = Official_API_KEY
end_point = Official_OPENAI_URL
timeout = 180
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
if bool_img:
base64_image = encode_image(image_path)
message = [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": "auto"
}
}
]
}
]
else:
message = [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
},
]
# message = [
# {
# "role": "user",
# "content": {
# "type": "text",
# "text": prompt
# },
# },
# {
# "role": "system",
# "content": {
# "type": "text",
# "text": '你是一个公式过滤专家,能清楚的分辨出公式与(表达式、运算式)的差别,遵循用户的指令,将公式以及包含公式的相关文本内容提取出来。'
# },
# }
# ]
#message = [system_prompt, user_prompt]
payload = {
"model": "gpt-4o",
"messages": message,
"max_tokens": 4096,
"temperature": 0.1
}
proxies = {"http": None, "https": None}
response = requests.post(url=end_point, headers=headers, json=payload, timeout=timeout, proxies=proxies)
logger.info(f'gpt4 response={response}')
response_json = response.json()
logger.info('response_json={}'.format(response_json))
#记录日志信息
return_params['input_token']= response_json['usage']['prompt_tokens']
return_params['output_token']= response_json['usage']['completion_tokens']
return_params['status']='Failed'
return_params['api-key']=api_key
return_params['url']=end_point
return_params['model']=response_json['model']
return_params['text']=text
if response.status_code == 200:
return_params['status']='OK'
if 'choices' in response_json and len(response_json['choices']) > 0:
model_reply = response_json['choices'][0]['message']['content']
#logger.info('gpt return infos={}'.format(model_reply))
logger.info('model return success!')
self.write_token_logs(return_params)
return model_reply
else:
logger.info('model_reply is Empty!')
return ''
else:
print("请求服务器错误")
return ''
# def layout_Rec(self,image_path):
# max_try=3
# flag_num=1
# while flag_num<max_try:
# if image_path is None or len(image_path)<=0:
# logger.error('路径错误!image_path={}'.format(image_path))
# return []
# #url = "http://192.168.1.235:30016/v1/dcg_layout"
# url = LAYOUT_CHECK_URL
# file = open(image_path, "rb")
# #logger.info('file={}'.format(file))
# #logger.info('image_path={}'.format(image_path))
# params = {
# "userid": "yxl_110",
# "client_id": "dcg-red-list"
# }
# headers = {"Authorization": "Bearer dcg-MTQ2MDRkYWRmNzRjMDg0ZjZmNTc3YTliMWM0YzYwYmVlZDE="}
# try:
# #proxies = {"http": None, "https": None}
# response = requests.post(url, files={"file": file}, data=params, headers=headers)
# if response.status_code == 200:
# response=response.json()
# return response['data']['Ids_Scores_boxes']
# else:
# flag_num+=1
# time.sleep(1)
# except Exception as e:
# logger.error(f"版面检测调用失败!{e},失败次数={flag_num}")
# logger.error(f"版面检测完全调用失败!")
# return []
# def layout_Rec(self,image_path):
# if image_path is None or len(image_path)<=0:
# logger.error('路径错误!image_path={}'.format(image_path))
# return []
# #url = "http://192.168.1.235:30016/v1/dcg_layout"
# url = LAYOUT_CHECK_URL
# logger.info('in layout_Rec! alrt:api key not save in datasets!')
# file = open(image_path, "rb")
# #logger.info('file={}'.format(file))
# #logger.info('image_path={}'.format(image_path))
# params = {
# "userid": "dcg-kb",
# "client_id": "dcg-red-list"
# }
# headers = {"Authorization": "Bearer dcg-MTQ2MDRkYWRmNzRjMDg0ZjZmNTc3YTliMWM0YzYwYmVlZDE="}
# try:
# response = requests.post(url, files={"file": file}, data=params, headers=headers)
# if response.status_code == 200:
# response=response.json()
# # 获取内部 JSON 字符串的值
# inner_data = json.loads(response['data'])
# Ids_Scores_boxes = json.loads(inner_data['Ids_Scores_boxes'])
# return Ids_Scores_boxes
# else:
# logger.error('layout 检测失败, image_path={}'.format(image_path))
# return []
# except Exception as e:
# logger.error('Layout Detection Error! image_path={},e={}'.format(image_path,e))
# return []
def layout_Rec(self, image_path):
if image_path is None or len(image_path) <= 0:
logger.error('路径错误!image_path={}'.format(image_path))
return []
url = LAYOUT_CHECK_URL
logger.info('in layout_Rec! alrt:api key not save in datasets!')
with open(image_path, "rb") as file:
params = {
"userid": "dcg-kb",
"client_id": "dcg-red-list"
}
headers = {"Authorization": "Bearer dcg-MTQ2MDRkYWRmNzRjMDg0ZjZmNTc3YTliMWM0YzYwYmVlZDE="}
try:
response = requests.post(url, files={"file": file}, data=params, headers=headers)
if response.status_code == 200:
response_data = response.json() # Already a dict, no need for json.loads()
logger.info(f'LayOut_response_data={response_data}')
if 'data' in response_data :
Ids_Scores_boxes=response_data['data']['Ids_Scores_boxes']
return Ids_Scores_boxes
else:
logger.error('Invalid format for data in response.')
else:
logger.error(f'layout 检测失败, image_path={image_path}, status_code={response.status_code}')
except Exception as e:
logger.error(f'Layout Detection Error! image_path={image_path}, e={str(e)}')
return []
def layout_Rec_19(self, image_path):
if image_path is None or len(image_path) <= 0:
logger.error('路径错误!image_path={}'.format(image_path))
return []
url = LAYOUT_CHECK_URL_19
logger.info('in layout_Rec! alrt:api key not save in datasets!')
with open(image_path, "rb") as file:
params = {
"userid": "dcg-kb",
"client_id": "dcg-red-list"
}
headers = {"Authorization": "Bearer dcg-MTQ2MDRkYWRmNzRjMDg0ZjZmNTc3YTliMWM0YzYwYmVlZDE="}
try:
response = requests.post(url, files={"file": file}, data=params, headers=headers)
if response.status_code == 200:
response_data = response.json()["data"] # Already a dict, no need for json.loads()
logger.info(f'LayOut_response_data={response_data}')
if response_data["boxes_num"]>0:
Ids_Scores_boxes=response_data['boxes']
#重新组装内容
box_detection_list=[[1,0.9,box] for box in Ids_Scores_boxes]
return box_detection_list
else:
logger.error('no detection boxes by layout server!')
return []
else:
logger.error(f'layout 检测失败, image_path={image_path}, status_code={response.status_code}')
return []
except Exception as e:
logger.error(f'Layout Detection Error! image_path={image_path}, e={str(e)}')
return []
return []
def get_sub_img_paths(self, image_path, formula_positions, output_folder):
try:
logger.info('in get_sub_img_paths')
logger.info('image_path={},formula_positions={}'.format(image_path, formula_positions))
# 检查输入图像路径是否存在
if not os.path.exists(image_path):
logger.info(f"Image file '{image_path}' does not exist.")
raise FileNotFoundError(f"Image file '{image_path}' does not exist.")
# 读取图像
image = cv2.imread(image_path)
if image is None:
logger.info(f"Failed to read the image file '{image_path}'.")
raise ValueError(f"Failed to read the image file '{image_path}'.")
# 创建输出文件夹
if not os.path.exists(output_folder):
os.makedirs(output_folder)
now_data = self.get_day_time()
output_folder = os.path.join(output_folder, now_data)
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# 存储提取图像路径的列表
extracted_image_paths = []
# 提取感兴趣区域并保存
for i, pos in enumerate(formula_positions):
try:
# 检查感兴趣区域的坐标是否合理
if len(pos) != 4:
logger.info(f"Invalid position data at index {i}: {pos}")
raise ValueError(f"Invalid position data at index {i}: {pos}")
x1, y1, x2, y2 = map(int, pos)
# 向外扩展20个像素,确保不超出图像边界
x1 = max(0, x1 - 10)
y1 = max(0, y1 - 10)
x2 = min(image.shape[1], x2 + 10)
y2 = min(image.shape[0], y2 + 10)
if x1 >= x2 or y1 >= y2:
raise ValueError(f"Invalid coordinates at index {i}: {pos}")
# 提取感兴趣区域
roi = image[y1:y2, x1:x2]
if roi.size == 0:
raise ValueError(f"Empty ROI at index {i}: {pos}")
# 保存提取的图像
output_path = os.path.join(output_folder, f"formula_{i+1}.png")
cv2.imwrite(output_path, roi)
extracted_image_paths.append(output_path)
except Exception as e:
print(f"Error processing position {i}: {e}")
return extracted_image_paths
except Exception as e:
print(f"An error occurred: {e}")
return []
def enlarge_image(self, image, scale_factor=2):
# 获取原图像的尺寸
width = int(image.shape[1] * scale_factor)
height = int(image.shape[0] * scale_factor)
# 使用Lanczos插值方法进行图像放大
enlarged_image = cv2.resize(image, (width, height), interpolation=cv2.INTER_LANCZOS4)
return enlarged_image
def get_sub_img_paths_enhanced(self, image_path, formula_positions, output_folder):
try:
logger.info('in get_sub_img_paths')
logger.info('image_path={},formula_positions={}'.format(image_path, formula_positions))
# 检查输入图像路径是否存在
if not os.path.exists(image_path):
logger.info(f"Image file '{image_path}' does not exist.")
raise FileNotFoundError(f"Image file '{image_path}' does not exist.")
# 读取图像
image = cv2.imread(image_path)
if image is None:
logger.info(f"Failed to read the image file '{image_path}'.")
raise ValueError(f"Failed to read the image file '{image_path}'.")
# 创建输出文件夹
if not os.path.exists(output_folder):
os.makedirs(output_folder)
now_data = self.get_day_time()
output_folder = os.path.join(output_folder, now_data)
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# 存储提取图像路径的列表
extracted_image_paths = []
# 提取感兴趣区域并保存
for i, pos in enumerate(formula_positions):
try:
# 检查感兴趣区域的坐标是否合理
if len(pos) != 4:
logger.info(f"Invalid position data at index {i}: {pos}")
raise ValueError(f"Invalid position data at index {i}: {pos}")
x1, y1, x2, y2 = map(int, pos)
# 向外扩展20个像素,确保不超出图像边界
x1 = max(0, x1 - 10)
y1 = max(0, y1 - 10)
x2 = min(image.shape[1], x2 + 10)
y2 = min(image.shape[0], y2 + 10)
if x1 >= x2 or y1 >= y2:
raise ValueError(f"Invalid coordinates at index {i}: {pos}")
# 提取感兴趣区域
roi = image[y1:y2, x1:x2]
if roi.size == 0:
raise ValueError(f"Empty ROI at index {i}: {pos}")
# 增强提取的图像
enlarged_roi = self.enlarge_image(roi)
logger.info('增强图像')
# 保存提取的图像
output_path = os.path.join(output_folder, f"formula_{i+1}.png")
cv2.imwrite(output_path, enlarged_roi)
extracted_image_paths.append(output_path)
except Exception as e:
print(f"Error processing position {i}: {e}")
return extracted_image_paths
except Exception as e:
print(f"An error occurred: {e}")
return []
def get_sub_img_paths_enhanced_v2(self, image_path, formula_positions, output_folder):
try:
#logger.info('in get_sub_img_paths')
logger.info('image_path={},formula_positions={}'.format(image_path, formula_positions))
# 检查输入图像路径是否存在
if not os.path.exists(image_path):
logger.info(f"Image file '{image_path}' does not exist.")
raise FileNotFoundError(f"Image file '{image_path}' does not exist.")
# 读取图像
image = cv2.imread(image_path)
if image is None:
logger.info(f"Failed to read the image file '{image_path}'.")
raise ValueError(f"Failed to read the image file '{image_path}'.")
# 创建输出文件夹
if not os.path.exists(output_folder):
os.makedirs(output_folder)
now_data = self.get_day_time()
output_folder = os.path.join(output_folder, now_data)
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# 存储提取图像路径的列表
extracted_image_paths = []
# 提取感兴趣区域并保存
for i, pos in enumerate(formula_positions):
try:
# 检查感兴趣区域的坐标是否合理
if len(pos) != 4:
logger.info(f"Invalid position data at index {i}: {pos}")
raise ValueError(f"Invalid position data at index {i}: {pos}")
x1, y1, x2, y2 = map(int, pos)
# 向外扩展20个像素,确保不超出图像边界
x1 = max(0, x1 - 15)
y1 = max(0, y1 - 15)
x2 = min(image.shape[1], x2 + 15)
y2 = min(image.shape[0], y2 + 15)
if x1 >= x2 or y1 >= y2:
raise ValueError(f"Invalid coordinates at index {i}: {pos}")
# 提取感兴趣区域
roi = image[y1:y2, x1:x2]
if roi.size == 0:
raise ValueError(f"Empty ROI at index {i}: {pos}")
# 增强提取的图像
enlarged_roi = self.enlarge_image(roi)
#logger.info('增强图像')
# # 如果子图大小为整张图大小的3/4及以上,对其进行二分
# if (x2 - x1) >= (0.75 * image.shape[1]) and (y2 - y1) >= (0.75 * image.shape[0]):
# logger.info(f"Sub-image at index {i} is too large, splitting into smaller sections.")
# mid_y = y1 + (y2 - y1) // 2
# for j, (start_y, end_y) in enumerate([(y1, mid_y), (mid_y, y2)]):
# sub_roi = image[start_y:end_y, x1:x2]
# sub_enlarged_roi = self.enlarge_image(sub_roi)
# sub_output_path = os.path.join(output_folder, f"formula_{i+1}_{j+1}.png")
# cv2.imwrite(sub_output_path, sub_enlarged_roi)
# extracted_image_paths.append(sub_output_path)
# 如果子图大小为整张图大小的3/4及以上,且子图的宽度或高度大于400像素,对其进行二分
if ((x2 - x1) * (y2 - y1)) >= (0.5 * image.shape[1] * image.shape[0]) and ((x2 - x1) > 600 or (y2 - y1) > 600):
logger.info(f"Sub-image at index {i} is too large, splitting into smaller sections.")
mid_y = y1 + (y2 - y1) // 2
for j, (start_y, end_y) in enumerate([(y1, mid_y), (mid_y, y2)]):
sub_roi = image[start_y:end_y, x1:x2]
sub_enlarged_roi = self.enlarge_image(sub_roi)
sub_output_path = os.path.join(output_folder, f"formula_{i+1}_{j+1}.png")
cv2.imwrite(sub_output_path, sub_enlarged_roi)
extracted_image_paths.append(sub_output_path)
else:
# 保存提取的图像
output_path = os.path.join(output_folder, f"formula_{i+1}.png")
cv2.imwrite(output_path, enlarged_roi)
extracted_image_paths.append(output_path)
except Exception as e:
logger.error(f"Error processing position {i}: {e}")
return extracted_image_paths
except Exception as e:
logger.error(f"An error occurred: {e}")
return []
def get_sub_img_paths_enhanced_v4(self, image_path, formula_positions, output_folder):
try:
# 检查输入图像路径是否存在
if not os.path.exists(image_path):
logger.info(f"Image file '{image_path}' does not exist.")
raise FileNotFoundError(f"Image file '{image_path}' does not exist.")
# 读取图像
image = cv2.imread(image_path)
if image is None:
logger.info(f"Failed to read the image file '{image_path}'.")
raise ValueError(f"Failed to read the image file '{image_path}'.")
now_times_miles=get_millisecond_time()
# 创建输出文件夹
# output_folder=os.path.join(output_folder,str(now_times_miles))
if not os.path.exists(output_folder):
os.makedirs(output_folder)
#unique_file_name=image_path.split('/')[-1].split('.')[0]
#output_folder = os.path.join(output_folder, unique_file_name)
# if not os.path.exists(output_folder):
# os.makedirs(output_folder)
# 存储提取图像路径的列表
extracted_image_paths = []
# 提取感兴趣区域并保存
for i, pos in enumerate(formula_positions):
try:
# 检查感兴趣区域的坐标是否合理
if len(pos) != 4:
logger.info(f"Invalid position data at index {i}: {pos}")
raise ValueError(f"Invalid position data at index {i}: {pos}")
x1, y1, x2, y2 = map(int, pos)
# 向外扩展20个像素,确保不超出图像边界
x1 = max(0, x1 - 20)
y1 = max(0, y1 - 20)
x2 = min(image.shape[1], x2 + 20)
y2 = min(image.shape[0], y2 + 20)
if x1 >= x2 or y1 >= y2:
raise ValueError(f"Invalid coordinates at index {i}: {pos}")
# 提取感兴趣区域
roi = image[y1:y2, x1:x2]
if roi.size == 0:
raise ValueError(f"Empty ROI at index {i}: {pos}")
#enlarged_roi=roi
# 增强提取的图像
enlarged_roi = self.enlarge_image(roi)
output_path = os.path.join(output_folder,f"formula_{i+1}.png")
cv2.imwrite(output_path, enlarged_roi)
extracted_image_paths.append(output_path)
except Exception as e:
logger.error(f"Error processing position {i}: {e}")
return extracted_image_paths
except Exception as e:
logger.error(f"An error occurred: {e}")
return []
def get_sub_img_paths_enhanced_NoClip(self, image_path, formula_positions, output_folder):
try:
#logger.info('in get_sub_img_paths')
logger.info('image_path={},formula_positions={}'.format(image_path, formula_positions))
# 检查输入图像路径是否存在
if not os.path.exists(image_path):
logger.info(f"Image file '{image_path}' does not exist.")
raise FileNotFoundError(f"Image file '{image_path}' does not exist.")
# 读取图像
image = cv2.imread(image_path)
if image is None:
logger.info(f"Failed to read the image file '{image_path}'.")
raise ValueError(f"Failed to read the image file '{image_path}'.")
# 创建输出文件夹
if not os.path.exists(output_folder):
os.makedirs(output_folder)
now_data = self.get_day_time()
output_folder = os.path.join(output_folder, now_data)
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# 存储提取图像路径的列表
extracted_image_paths = []
# 提取感兴趣区域并保存
for i, pos in enumerate(formula_positions):
try:
# 检查感兴趣区域的坐标是否合理
if len(pos) != 4:
logger.info(f"Invalid position data at index {i}: {pos}")
raise ValueError(f"Invalid position data at index {i}: {pos}")
x1, y1, x2, y2 = map(int, pos)
# 向外扩展20个像素,确保不超出图像边界
x1 = max(0, x1 - 5)
y1 = max(0, y1 - 5)
x2 = min(image.shape[1], x2 + 5)
y2 = min(image.shape[0], y2 + 5)
if x1 >= x2 or y1 >= y2:
raise ValueError(f"Invalid coordinates at index {i}: {pos}")
# 提取感兴趣区域
roi = image[y1:y2, x1:x2]
if roi.size == 0:
raise ValueError(f"Empty ROI at index {i}: {pos}")
enlarged_roi=roi
# 增强提取的图像
#enlarged_roi = self.enlarge_image(roi)
#logger.info('增强图像')
# # 如果子图大小为整张图大小的3/4及以上,对其进行二分
# if (x2 - x1) >= (0.75 * image.shape[1]) and (y2 - y1) >= (0.75 * image.shape[0]):
# logger.info(f"Sub-image at index {i} is too large, splitting into smaller sections.")
# mid_y = y1 + (y2 - y1) // 2
# for j, (start_y, end_y) in enumerate([(y1, mid_y), (mid_y, y2)]):
# sub_roi = image[start_y:end_y, x1:x2]
# sub_enlarged_roi = self.enlarge_image(sub_roi)
# sub_output_path = os.path.join(output_folder, f"formula_{i+1}_{j+1}.png")
# cv2.imwrite(sub_output_path, sub_enlarged_roi)
# extracted_image_paths.append(sub_output_path)
# 如果子图大小为整张图大小的3/4及以上,且子图的宽度或高度大于400像素,对其进行二分
# 保存提取的图像
output_path = os.path.join(output_folder, f"formula_{i+1}.png")
cv2.imwrite(output_path, enlarged_roi)
extracted_image_paths.append(output_path)
except Exception as e:
logger.error(f"Error processing position {i}: {e}")
return extracted_image_paths
except Exception as e:
logger.error(f"An error occurred: {e}")
return []
def get_sub_img_paths_enhanced_v3(self, image_path, formula_positions, output_folder):
try:
#logger.info('in get_sub_img_paths')
logger.info('image_path={},formula_positions={}'.format(image_path, formula_positions))
# 检查输入图像路径是否存在
if not os.path.exists(image_path):
logger.info(f"Image file '{image_path}' does not exist.")
raise FileNotFoundError(f"Image file '{image_path}' does not exist.")
# 读取图像
image = cv2.imread(image_path)
if image is None:
logger.info(f"Failed to read the image file '{image_path}'.")
raise ValueError(f"Failed to read the image file '{image_path}'.")
# 创建输出文件夹
if not os.path.exists(output_folder):
os.makedirs(output_folder)
now_data = self.get_day_time()
output_folder = os.path.join(output_folder, now_data)
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# 存储提取图像路径的列表
extracted_image_paths = []
# 提取感兴趣区域并保存
for i, pos in enumerate(formula_positions):
try:
# 检查感兴趣区域的坐标是否合理
if len(pos) != 4:
logger.info(f"Invalid position data at index {i}: {pos}")
raise ValueError(f"Invalid position data at index {i}: {pos}")
x1, y1, x2, y2 = map(int, pos)
# 向外扩展20个像素,确保不超出图像边界
x1 = max(0, x1 - 10)
y1 = max(0, y1 - 15)
x2 = min(image.shape[1], x2 + 10)
y2 = min(image.shape[0], y2 + 15)
if x1 >= x2 or y1 >= y2:
raise ValueError(f"Invalid coordinates at index {i}: {pos}")
# 提取感兴趣区域
roi = image[y1:y2, x1:x2]
if roi.size == 0:
raise ValueError(f"Empty ROI at index {i}: {pos}")
# 增强提取的图像
enlarged_roi = self.enlarge_image(roi)
#logger.info('增强图像')
if ((x2 - x1) * (y2 - y1)) >= (0.75 * image.shape[1] * image.shape[0]) and ((x2 - x1) > 800 or (y2 - y1) > 800):
logger.info('子图裁剪-----------------------')
logger.info(f"Sub-image at index {i} is too large, splitting into smaller sections.")
if (x2 - x1) > (y2 - y1): # 横向裁剪
mid_x = x1 + (x2 - x1) // 2
for j, (start_x, end_x) in enumerate([(x1, mid_x), (mid_x, x2)]):
sub_roi = image[y1:y2, start_x:end_x]
sub_enlarged_roi = self.enlarge_image(sub_roi)
sub_output_path = os.path.join(output_folder, f"formula_{i+1}_{j+1}.png")
cv2.imwrite(sub_output_path, sub_enlarged_roi)
extracted_image_paths.append(sub_output_path)
else: # 竖向裁剪
mid_y = y1 + (y2 - y1) // 2
for j, (start_y, end_y) in enumerate([(y1, mid_y), (mid_y, y2)]):
sub_roi = image[start_y:end_y, x1:x2]
sub_enlarged_roi = self.enlarge_image(sub_roi)
sub_output_path = os.path.join(output_folder, f"formula_{i+1}_{j+1}.png")
cv2.imwrite(sub_output_path, sub_enlarged_roi)
extracted_image_paths.append(sub_output_path)
else:
# 保存提取的图像
output_path = os.path.join(output_folder, f"formula_{i+1}.png")
cv2.imwrite(output_path, enlarged_roi)
extracted_image_paths.append(output_path)
except Exception as e:
logger.error(f"Error processing position {i}: {e}")
return extracted_image_paths
except Exception as e:
logger.error(f"An error occurred: {e}")
return []
def Image2Latex(self, img_path, inference_mode='cuda', num_beam=1, mix=True):
"""
Posts an image to the specified server URL and returns the server's response and the elapsed time.
Parameters:
server_url (str): The URL of the server to post the image to.
img_path (str): The path to the image file to be posted.
Returns:
tuple: A tuple containing the response from the server and the elapsed time for the request.
"""
rec_server_url = "http://localhost:8816/predict/img2latex"
logger.info('Image2Latex URL needed to certain!')
with open(img_path, "rb") as image_file:
files = {"image": image_file}
data = {
"inference_mode": inference_mode,
"num_beam": num_beam,
"mix": str(mix).lower() # 将布尔值转换为字符串形式
}
response = requests.post(rec_server_url, files=files, data=data)
if response.status_code == 200:
logger.info("img2latex成功 预测结果:{}".format(response.json()))
return response.json()
else:
logger.info('img2latex失败'.format(response))
return {'result':''}
#将子图的宽度都拉到同一批数据的统一的值,实现宽度拉伸
def get_sub_img_paths_enhanced_v3(self, image_path, formula_positions, output_folder):
try:
#logger.info('in get_sub_img_paths')
logger.info('image_path={},formula_positions={}'.format(image_path, formula_positions))
# 检查输入图像路径是否存在
if not os.path.exists(image_path):
logger.info(f"Image file '{image_path}' does not exist.")
raise FileNotFoundError(f"Image file '{image_path}' does not exist.")
# 读取图像
image = cv2.imread(image_path)
if image is None:
logger.info(f"Failed to read the image file '{image_path}'.")
raise ValueError(f"Failed to read the image file '{image_path}'.")
# 创建输出文件夹
if not os.path.exists(output_folder):
os.makedirs(output_folder)
now_data = self.get_day_time()
output_folder = os.path.join(output_folder, now_data)
if not os.path.exists(output_folder):
os.makedirs(output_folder)
max_width = 0
for position in formula_positions:
coordinates = position[2]
width = coordinates[2] - coordinates[0]
if width > max_width:
max_width = width
max_width
# 存储提取图像路径的列表
extracted_image_paths = []
# 提取感兴趣区域并保存
for i, pos in enumerate(formula_positions):
try:
# 检查感兴趣区域的坐标是否合理
if len(pos) != 4:
logger.info(f"Invalid position data at index {i}: {pos}")
raise ValueError(f"Invalid position data at index {i}: {pos}")
x1, y1, x2, y2 = map(int, pos)
# 向外扩展20个像素,确保不超出图像边界
x1 = max(0, x1 - 10)
y1 = max(0, y1 - 10)
x2 = min(image.shape[1], x2 + 10)
y2 = min(image.shape[0], y2 + 10)
if x1 >= x2 or y1 >= y2:
raise ValueError(f"Invalid coordinates at index {i}: {pos}")
# 提取感兴趣区域
roi = image[y1:y2, x1:x2]
if roi.size == 0:
raise ValueError(f"Empty ROI at index {i}: {pos}")
# 增强提取的图像
enlarged_roi = self.enlarge_image(roi)
#logger.info('增强图像')
# # 如果子图大小为整张图大小的3/4及以上,对其进行二分
# if (x2 - x1) >= (0.75 * image.shape[1]) and (y2 - y1) >= (0.75 * image.shape[0]):
# logger.info(f"Sub-image at index {i} is too large, splitting into smaller sections.")
# mid_y = y1 + (y2 - y1) // 2
# for j, (start_y, end_y) in enumerate([(y1, mid_y), (mid_y, y2)]):
# sub_roi = image[start_y:end_y, x1:x2]
# sub_enlarged_roi = self.enlarge_image(sub_roi)
# sub_output_path = os.path.join(output_folder, f"formula_{i+1}_{j+1}.png")
# cv2.imwrite(sub_output_path, sub_enlarged_roi)
# extracted_image_paths.append(sub_output_path)
# 如果子图大小为整张图大小的3/4及以上,且子图的宽度或高度大于400像素,对其进行二分
if ((x2 - x1) >= (0.75 * image.shape[1]) or (y2 - y1) >= (0.75 * image.shape[0])) and ((x2 - x1) > 600 and (y2 - y1) > 600):
logger.info(f"Sub-image at index {i} is too large, splitting into smaller sections.")
mid_y = y1 + (y2 - y1) // 2
for j, (start_y, end_y) in enumerate([(y1, mid_y), (mid_y, y2)]):
sub_roi = image[start_y:end_y, x1:x2]
sub_enlarged_roi = self.enlarge_image(sub_roi)
sub_output_path = os.path.join(output_folder, f"formula_{i+1}_{j+1}.png")
cv2.imwrite(sub_output_path, sub_enlarged_roi)
extracted_image_paths.append(sub_output_path)
else:
# 保存提取的图像
output_path = os.path.join(output_folder, f"formula_{i+1}.png")
cv2.imwrite(output_path, enlarged_roi)
extracted_image_paths.append(output_path)
except Exception as e:
logger.error(f"Error processing position {i}: {e}")
return extracted_image_paths
except Exception as e:
logger.error(f"An error occurred: {e}")
return []
# Make sure that self.get_day_time() and self.enlarge_image() methods are properly defined in your class.
# 获取毫秒级时间
def get_millisecond_time(self):
current_time = datetime.now()
time_str = current_time.strftime("%Y%m%d%H%M%S%f")[:-3]
return time_str
def get_day_time(self):
# 获取当前日期和时间
now = datetime.now()
# 格式化日期和时间为字符串,格式为 "YYYYMMDD_HHMMSS"
formatted_time = now.strftime("%Y%m%d_%H%M%S")
return formatted_time
def convert_to_dict(self,all_res):
result = []
for item in all_res:
try:
# 将字符串转换为字典
converted_item = ast.literal_eval(item)
result.append(converted_item)
except Exception as e:
print(f"Error converting item: {item}, error: {e}")
result.append(item) # 保留原始字符串以防转换失败
return result
# def save_to_html():
def filter_detection_results(self,S):
"""
过滤出confidence大于0.8的结果
参数:
- S (list): 检测结果列表
返回:
- bool: 如果过滤后的结果个数大于1,返回True,否则返回False
- list: 过滤后的结果列表
"""
filtered_results = [result for result in S if result['confidence'] > 0.8]
return len(filtered_results) > 1
def annotate_image_with_detections(self,original_image_path, output_image_path, detections):
"""
在图像上绘制检测模型的结果并保存。
参数:
- original_image_path (str): 原始图像文件路径
- output_image_path (str): 保存带注释的图像路径
- detections (list): 检测结果列表,每个元素应包含 'p'(坐标)、'h'(高度)、'w'(宽度)、'label' 和 'confidence'
返回:
- None
"""
image = cv2.imread(original_image_path)
for item in detections:
x = item['p']['x']
y = item['p']['y']
width = item['w']
height = item['h']
confidence = item['confidence']
color = (0, 255, 0) if confidence > 0.8 else (0, 0, 255)
cv2.rectangle(image, (x, y), (x + width, y + height), color, 2)
label_text = f"{item['label']} ({confidence:.2f})"
cv2.putText(image, label_text, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
cv2.imwrite(output_image_path, image)
def formula_detection_fuc(self,img_path):
"""
处理图像路径,调用检测服务并过滤检测结果。
参数:
- img_path (str): 输入图像路径
返回:
- bool: 是否有足够高置信度的结果
- str: 错误消息(如果有的话)
"""
with open(img_path, 'rb') as img:
files = {'img': img}
try:
response = requests.post(FORMULA_DETECTION_URL, files=files)
except requests.exceptions.RequestException as e:
logger.error(f"请求公式检测服务失败: {e}")
return False, f"请求公式检测服务失败: {e}"
if response.status_code != 200:
logger.error(f"检测服务返回错误: {response.status_code}")
return False, f"检测服务返回错误: {response.status_code}"
try:
response_infos = json.loads(response.text)
except json.JSONDecodeError as e:
logger.error(f"解析检测服务返回结果失败: {e}")
return False, "解析检测服务返回结果失败"
#返回结果
try:
has_valid_detections = self.filter_detection_results(response_infos)
return has_valid_detections,'检查成功'
except json.JSONDecodeError as e:
logger.error(f"解析检测服务返回结果失败: {e}")
return False, "解析检测服务返回结果失败"
def clean_return_results(self,formula_results):
def check_bool(text):
"""
检查LaTeX文本中是否包含大于10的数字。
参数:
text (str): LaTeX文本字符串。
返回:
bool: 如果所有数字都小于等于10,返回True;否则返回False。
"""
if len(text)==0:
return False
# 使用正则表达式匹配数字
numbers = re.findall(r'\b\d+\b', text)
# 检查是否存在大于10的数字
for number in numbers:
if int(number) > 10:
return False
return True
return_results=[]
for block in formula_results:
if check_bool(block['corrected_formula']):
return_results.append(block)
return return_results
def is_break(self, text):
# 定义关键词列表
keywords = [
'证明', '解', '例题', '名师经验谈', '解析', '答案', 'A.','B.','C.','D.'
]
# 检查关键词是否存在于文本中
for keyword in keywords:
if keyword in text:
return False # 如果发现关键词,返回 False
return True # 如果所有关键词都不存在,返回 True
if __name__ == '__main__':
formula_tool=Formula_Checker()
img_path = "/data/wangtengbo/formula_TexTeller/TexTeller/src/9c18ed3e78b34fb7b31c9c73d044fcac.jpg"
result, message = formula_tool.formula_detection_fuc(img_path)
if not result:
logger.info(f"处理结果: {message}")
else:
logger.info("处理成功并生成可视化图像。")
import re
from utils.common import Singleton
# class FormulaProcessor(metaclass=Singleton):
class FormulaProcessor():
def __init__(self):
# 预编译正则表达式,提高性能
self.chinese_characters_pattern = re.compile(r'[\u4e00-\u9fff]')
self.english_characters_pattern = re.compile(r'[a-zA-Z]')
self.digits_pattern = re.compile(r'\d')
self.expression_pattern = re.compile(r'(\b[a-zA-Z]+\b\s*=\s*\d+)')
self.latex_formula_pattern = re.compile(r'\\[a-zA-Z]+\{.*?\}|\$.*?\$')
self.exclusion_keywords = ['例题', '题目','答案','练习','新课教授','化简','例','判断','选择','填空','计算题']
self.allowed_characters_pattern = re.compile(r'^[\u4e00-\u9fffA-Za-z]+$')
def contains_latex_formula(self, string):
return bool(self.latex_formula_pattern.search(string))
def is_formula(self, string):
# 规则1:包含等号或者包含LaTeX公式
if '=' in string or '>' in string or '<' in string or '\\' in string or '+' in string or '-' in string or '*' in string or '^' in string :
return True
if not self.contains_latex_formula(string):
return False
if self.allowed_characters_pattern.match(string):
return False
# 规则3:包含“例题”、“题”、“题目”的都不是公式
if any(keyword in string for keyword in self.exclusion_keywords):
return False
# # 规则2:只有中文、英文,或者同时包含中文以及英文和数字的肯定不是公式,除非包含LaTeX公式
if self.chinese_characters_pattern.search(string) and (
self.english_characters_pattern.search(string) or self.digits_pattern.search(string)):
if not self.contains_latex_formula(string):
return False
# 通过所有规则检查,返回True
return True
# 测试示例
if __name__ == "__main__":
checker = FormulaProcessor()
test_strings = [
"1.了解二次根式、最简二次根式的概念,理解二次根式的性质 2.了解二次根式(根号下仅限于数)的加、减、乘、除运算法则,会用它们进 行有关的简单四则运算. 一 课时分配 本章教学约需8 课时,具体安排如下: 16.1二次根式2 课时 16.2二次根式的乘除3 课时 16.3二次根式的加减2 课时 小结1课时。",
"最简二次根式的概念,理解二次根式的性质",
'abdfsfdsfewfdsfsdfew,'
'2为了民族复兴的梦想,我们从1840年的海面出发'
"x^2 + y^2 = z^2",
"E = mc^2",
"F = ma",
'a*b=10',
"理解并掌握(\sqrt{a})^{2}=a (a \geq 0)",
"本章内容主要有两个部分,它们分别是二次根式的有关概念、性质和二次根 式的四则运算. 本章的第一部分是二次根式的有关概念和性质.教材从两个电视塔的传播 半径之比出发,引入人二次根式的概念.接着根据有理数的算术平方根的意义,顺 理成章地推导出二次根式的两个性质: $\left( \sqrt{a} \right)^{3}=a \left( a \geq 0 \right) \rightarrow \sqrt{a^{3}}=a \left( a \geq 0 \right)$ 本章的第二部分是二次根式的四则运算.教材遵循由特殊到一般的规律,由 学生通过分析、概括、交流、归纳等过程,探究得到二次根式的乘除运算法则: $\sqrt{a}+ \sqrt{b}= \sqrt{ab} \left( a \geq 0,b \geq 0 \right) \text{和} \frac{ \sqrt{a}}{ \sqrt{b}}= \sqrt{ \frac{a}{b}} \left( a \geq 0,b>0 \right)$ .在此基础上,又通过进一步 类比,引出二次根式的加减运算.教材注重知识之间的联系,将乘法公式运用到 二次根式的四则运算中,以简化二次根式的运算"
,
'1.直线与抛物线的交点问题 要解决直线与抛物线的位置关系问题,可把直线方程与抛物线方程联 立,消去y(或消去x)得出关于x(或关于y)的一个方程 $ax^{2}+bx+c=0$ $y$ ,其中二 次项系 $a$ 有可能为0,此时直线与抛物线有一个交点. 当二次项系数 $a \neq 0$ 时, $\Delta=b^{2}-4ac$ 若△=0,则直线与抛物线没有公共点; 若 $\Delta>0$ ,则直线与抛物线有且只有一个公共点; 若 $\Delta<0$ 则直线与抛物线有两个不同的公共点. 2.弦长问题 设弦的端点为 $A \left( x_{1},y_{1} \right),B \left( x_{2},y_{2} \right)$ (1)一般弦长: $\left| AB \right|= \sqrt{1+k^{2}} \left| x_{1}-x_{2} \right.$ 域| $AB \left| \right.= \sqrt{1+ \frac{1}{k^{2}}} \left| y_{1}-y_{2} \right|$ (其中 k为弦所在直线的斜率) (2)焦点弦长: $\left| AB \right|=x_{1}+x_{2}-p.$ 3.中点弦问题 若 $M \left( x_{0},y_{0} \right)$ 是抛物线 $y^{2}=2px \left( p>0 \right)$ 的弦 $AB$ 的中点,则直线 $AB$ 的斜率 为 $k_{AB}= \frac{p}{y_{0}}$'
,
'库伦定律表达式:F=k1y2 r2' #:True
]
results = [checker.is_formula(s) for s in test_strings]
print(results) # [False, False, True, True, True, True, True]
# encoding=utf8
import sys
import os
import time
from loguru import logger
import datetime
from sqlalchemy import create_engine, UniqueConstraint
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy import Column, String, Integer, DateTime, Text, SmallInteger
from config.config import MYSQL_DB_URL
engine_knowledge = create_engine(
url=MYSQL_DB_URL,
max_overflow = 10, #超过连接池大小外最多创建的连接,为0表示超过5个连接后,其他连接请求会阻塞 (默认为10)
pool_size = 50, #连接池大小(默认为5)
pool_timeout = 30, #连接线程池中,没有连接时最多等待的时间,不设置无连接时直接报错 (默认为30)
pool_recycle = 3600 #多久之后对线程池中的线程进行一次连接的回收(重置) (默认为-1)
)
Session_Knowledge = sessionmaker(bind=engine_knowledge)
BaseKnowledge = declarative_base()
class KnowledgeBasePrompts(BaseKnowledge):
__tablename__ = 'knowledge-base-prompts'
id = Column(Integer, primary_key=True)
task_id = Column(Integer, unique=True)
task_name = Column(String(50))
prompt = Column(Text)
llm = Column(String(30))
author = Column(String(50))
status = Column(Integer)
attr = Column(Text)
create_time = Column(DateTime, default=datetime.datetime.now)
update_time = Column(DateTime, default=datetime.datetime.now)
class CorrectLLMStatisticsLog(BaseKnowledge):
__tablename__ = 'correct_llm_statistics_log'
id = Column(Integer, primary_key=True)
env = Column(String(32))
source_type = Column(String(32))
server_name = Column(String(64))
user_id = Column(String(64))
publish_id = Column(String(64))
request = Column(String(512))
model = Column(String(512))
url = Column(String(512))
api_key = Column(String(512))
api_version = Column(String(512))
input_token = Column(Integer)
output_token = Column(Integer)
start_time = Column(DateTime)
end_time = Column(DateTime)
message = Column(String(512))
doc_id = Column(String(64))
fragment_id = Column(String(512))
book_name = Column(String(512))
text = Column(Text)
backup = Column(Text)
update_time = Column(DateTime, default=datetime.datetime.now)
"""
类DBUtils的定义
"""
class DBUtils(object):
def __init__(self):
pass
@staticmethod
def get_prompt(task_id, task_name=None):
session = Session_Knowledge()
try:
prompt_obj = session.query(KnowledgeBasePrompts).filter_by(task_id=task_id).first()
return prompt_obj
except Exception as e:
logger.error("get_prompt error: {0}",e)
return None
finally:
session.close()
@staticmethod
def insert_llm_log(env=None, source_type=None, server_name=None, user_id=None, publish_id=None,
request=None, model=None, url=None, api_key=None,api_version=None,
input_token=0, output_token=0, start_time=None, end_time=None,
message=None, doc_id=None, fragment_id=None,book_name=None, text=None,
backup=None):
begin_time = datetime.datetime.now()
session = Session_Knowledge()
try:
log = CorrectLLMStatisticsLog(env=env, source_type=source_type, server_name=server_name,
user_id=user_id, publish_id=publish_id, request=request,
model=model, url=url, api_key=api_key, api_version=api_version,
input_token=input_token, output_token=output_token,
start_time=start_time, end_time=end_time, message=message,
doc_id=doc_id, fragment_id=fragment_id,book_name=book_name, text=text,
backup=backup)
session.add(log)
session.commit()
except Exception as e:
logger.error("insert_llm_log error: {}".format(e))
session.rollback()
finally:
session.close()
over_time = datetime.datetime.now()
logger.info("insert_llm_log elapsed time: {}".format(over_time - begin_time))
def test_dbutil():
pass
if __name__ == '__main__':
test_dbutil()
\ No newline at end of file
import os
from openai import OpenAI
from loguru import logger
from config.config import QWEN_API_KEY,QWEN_URL,QWEN_MODEL
def qwen_vl_infer(
image_url: str,
system_prompt:str,
user_prompt:str
) -> str:
"""
使用指定的多模态模型,对给定图片 URL 进行描述。
Args:
api_key (str): OpenAI API 密钥。
image_url (str): 要描述的图片地址。
model (str): 模型名称,默认为 qwen-vl-max-latest。
base_url (str): 接口基础 URL。
Returns:
str: 模型返回的描述文本;出错时返回空字符串。
"""
try:
client = OpenAI(
api_key=QWEN_API_KEY,
base_url=QWEN_URL,
)
messages = [
{
"role": "system",
"content": [{"type": "text", "text": system_prompt}],
},
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": user_prompt},
],
},
]
completion = client.chat.completions.create(
model=QWEN_MODEL,
messages=messages,
)
description = completion.choices[0].message.content
logger.info("Received description from model")
return description
except Exception as e:
logger.error(f"Unexpected Qwen Infer error: {e},completion={completion}", exc_info=True)
return ""
\ No newline at end of file
import cv2
import os
class ImageProcessor:
def __init__(self):
pass
def enlarge_image(self, image, scale_factor=2):
# 获取原图像的尺寸
width = int(image.shape[1] * scale_factor)
height = int(image.shape[0] * scale_factor)
# 使用Lanczos插值方法进行图像放大
enlarged_image = cv2.resize(image, (width, height), interpolation=cv2.INTER_LANCZOS4)
return enlarged_image
def process_image(self, image, rois, output_folder):
extracted_image_paths = []
for i, (x1, y1, x2, y2) in enumerate(rois):
roi = image[y1:y2, x1:x2]
enlarged_roi = self.enlarge_image(roi)
if ((x2 - x1) * (y2 - y1)) >= (0.75 * image.shape[1] * image.shape[0]) and ((x2 - x1) > 600 and (y2 - y1) > 600):
# logger.info(f"Sub-image at index {i} is too large, splitting into smaller sections.")
if (x2 - x1) > (y2 - y1): # 横向裁剪
mid_x = x1 + (x2 - x1) // 2
for j, (start_x, end_x) in enumerate([(x1, mid_x), (mid_x, x2)]):
sub_roi = image[y1:y2, start_x:end_x]
sub_enlarged_roi = self.enlarge_image(sub_roi)
sub_output_path = os.path.join(output_folder, f"formula_{i+1}_{j+1}.png")
cv2.imwrite(sub_output_path, sub_enlarged_roi)
extracted_image_paths.append(sub_output_path)
else: # 竖向裁剪
mid_y = y1 + (y2 - y1) // 2
for j, (start_y, end_y) in enumerate([(y1, mid_y), (mid_y, y2)]):
sub_roi = image[start_y:end_y, x1:x2]
sub_enlarged_roi = self.enlarge_image(sub_roi)
sub_output_path = os.path.join(output_folder, f"formula_{i+1}_{j+1}.png")
cv2.imwrite(sub_output_path, sub_enlarged_roi)
extracted_image_paths.append(sub_output_path)
else:
# 保存提取的图像
output_path = os.path.join(output_folder, f"formula_{i+1}.png")
cv2.imwrite(output_path, enlarged_roi)
extracted_image_paths.append(output_path)
return extracted_image_paths
# 假设这里有图像和ROI的初始化代码
if __name__ == "__main__":
processor = ImageProcessor()
image = cv2.imread("/data/wangtengbo/formula_correct/test_data/SQW@MFFBGE28F2%S%TV]M[0.png") # 读取图像
rois = [(50, 50, 700, 1300), (800, 100, 1500, 1600)] # 假设这里有ROI的定义
output_folder = "output"
if not os.path.exists(output_folder):
os.makedirs(output_folder)
extracted_image_paths = processor.process_image(image, rois, output_folder)
print(extracted_image_paths)
import re
class FormulaProcessor:
def __init__(self):
# 预编译正则表达式,提高性能
self.chinese_characters_pattern = re.compile(r'[\u4e00-\u9fff]')
self.english_characters_pattern = re.compile(r'[a-zA-Z]')
self.digits_pattern = re.compile(r'\d')
self.expression_pattern = re.compile(r'(\b[a-zA-Z]+\b\s*=\s*\d+)')
self.latex_formula_pattern = re.compile(r'\\[a-zA-Z]+\{.*?\}|\$.*?\$')
self.exclusion_keywords = ['例题', '题目','答案','讲解','练习','新课教授','化简','例','概念辨析']
def contains_latex_formula(self, string):
return bool(self.latex_formula_pattern.search(string))
def is_formula(self, string):
# 规则1:包含等号或者包含LaTeX公式
if '=' not in string and not self.contains_latex_formula(string):
return False
# 规则3:包含“例题”、“题”、“题目”的都不是公式
if any(keyword in string for keyword in self.exclusion_keywords):
return False
# 规则2:只有中文、英文,或者同时包含中文以及英文和数字的肯定不是公式,除非包含LaTeX公式
if self.chinese_characters_pattern.search(string) and (
self.english_characters_pattern.search(string) or self.digits_pattern.search(string)):
if not self.contains_latex_formula(string):
return False
# 规则4:排除简单的表达式
if re.match(r'[a-zA-Z]\s*=\s*[a-zA-Z0-9+\-*/^]+', string):
return True
# 通过所有规则检查,返回True
return True
# 测试示例
if __name__ == "__main__":
checker = FormulaProcessor()
test_strings = [
"1.了解二次根式、最简二次根式的概念,理解二次根式的性质 2.了解二次根式(根号下仅限于数)的加、减、乘、除运算法则,会用它们进 行有关的简单四则运算. 一 课时分配 本章教学约需8 课时,具体安排如下: 16.1二次根式2 课时 16.2二次根式的乘除3 课时 16.3二次根式的加减2 课时 小结1课时。",
"最简二次根式的概念,理解二次根式的性质",
"x^2 + y^2 = z^2",
"E = mc^2",
"F = ma",
'a*b=10',
"理解并掌握(\sqrt{a})^{2}=a (a \geq 0)",
"本章内容主要有两个部分,它们分别是二次根式的有关概念、性质和二次根 式的四则运算. 本章的第一部分是二次根式的有关概念和性质.教材从两个电视塔的传播 半径之比出发,引入人二次根式的概念.接着根据有理数的算术平方根的意义,顺 理成章地推导出二次根式的两个性质: $\left( \sqrt{a} \right)^{3}=a \left( a \geq 0 \right) \rightarrow \sqrt{a^{3}}=a \left( a \geq 0 \right)$ 本章的第二部分是二次根式的四则运算.教材遵循由特殊到一般的规律,由 学生通过分析、概括、交流、归纳等过程,探究得到二次根式的乘除运算法则: $\sqrt{a}+ \sqrt{b}= \sqrt{ab} \left( a \geq 0,b \geq 0 \right) \text{和} \frac{ \sqrt{a}}{ \sqrt{b}}= \sqrt{ \frac{a}{b}} \left( a \geq 0,b>0 \right)$ .在此基础上,又通过进一步 类比,引出二次根式的加减运算.教材注重知识之间的联系,将乘法公式运用到 二次根式的四则运算中,以简化二次根式的运算"
,
'1.直线与抛物线的交点问题 要解决直线与抛物线的位置关系问题,可把直线方程与抛物线方程联 立,消去y(或消去x)得出关于x(或关于y)的一个方程 $ax^{2}+bx+c=0$ $y$ ,其中二 次项系 $a$ 有可能为0,此时直线与抛物线有一个交点. 当二次项系数 $a \neq 0$ 时, $\Delta=b^{2}-4ac$ 若△=0,则直线与抛物线没有公共点; 若 $\Delta>0$ ,则直线与抛物线有且只有一个公共点; 若 $\Delta<0$ 则直线与抛物线有两个不同的公共点. 2.弦长问题 设弦的端点为 $A \left( x_{1},y_{1} \right),B \left( x_{2},y_{2} \right)$ (1)一般弦长: $\left| AB \right|= \sqrt{1+k^{2}} \left| x_{1}-x_{2} \right.$ 域| $AB \left| \right.= \sqrt{1+ \frac{1}{k^{2}}} \left| y_{1}-y_{2} \right|$ (其中 k为弦所在直线的斜率) (2)焦点弦长: $\left| AB \right|=x_{1}+x_{2}-p.$ 3.中点弦问题 若 $M \left( x_{0},y_{0} \right)$ 是抛物线 $y^{2}=2px \left( p>0 \right)$ 的弦 $AB$ 的中点,则直线 $AB$ 的斜率 为 $k_{AB}= \frac{p}{y_{0}}$'
,
'库伦定律表达式:F=k1y2 r2' #:True
]
results = [checker.is_formula(s) for s in test_strings]
print(results) # [False, False, True, True, True, True, True, True, True]
import requests
import json
from config.config import APP_ID
from config.config import SECRET_CODE
from loguru import logger
def get_file_content(filePath):
with open(filePath, 'rb') as fp:
return fp.read()
class TextinOcr(object):
def __init__(self):
self.host = 'https://api.textin.com'
def recognize_pdf2md(self, image_path, options=None):
"""
pdf to markdown
:param options: request params
:param image: file bytes
:return: response
options = {
'pdf_pwd': None,
'dpi': 144, # 设置dpi为144
'page_start': 0,
'page_count': 1000, # 设置解析的页数为1000页
'apply_document_tree': 0,
'markdown_details': 1,
'page_details': 0, # 不包含页面细节信息
'table_flavor': 'md',
'get_image': 'none',
'parse_mode': 'scan', # 解析模式设为scan
}
"""
image=get_file_content(image_path)
if options==None:
options={
'table_flavor': 'md',
'parse_mode': 'scan', # 设置解析模式为scan模式
'page_details': 1, # 不包含页面细节
'markdown_details': 1,
'apply_document_tree': 1,
'dpi': 144 # 分辨率设置为144 dpi
}
url = self.host + '/ai/service/v1/pdf_to_markdown'
headers = {
'x-ti-app-id': APP_ID,
'x-ti-secret-code': SECRET_CODE
}
response=requests.post(url, data=image, headers=headers, params=options)
#logger.info(f'textln response=\n{response}')
#logger.info(f'Textln response infos={response}')
if response.status_code == 200:
time_cost=response.elapsed.total_seconds()
result = json.loads(response.text)
#logger.info(f'textln_init_infos={result}')
logger.info(f'textln response_time_cost={time_cost}\n\ntextln response=\n{result}')
return result['result']['markdown'],time_cost,result['result']['detail']
else:
logger.info('TextinOcr 请求失败 ,错误信息={}'.format(response))
return [],'',[]
if __name__ == "__main__":
# 请登录后前往 “工作台-账号设置-开发者信息” 查看 app-id/app-secret
textin = TextinOcr()
resp ,time_cost= textin.recognize_pdf2md(image_path='/data/wangtengbo/got_ocr2/infer/QQ图片20240926223216.png')
print("request time: ", time_cost)
print(resp)
# result = json.loads(resp.text)
# print(result)
# with open('./result.json', 'w', encoding='utf-8') as fw:
# json.dump(result, fw, indent=4, ensure_ascii=False)
# import re
# def filter_sentences_by_caption(context_info: str, caption_title: str, threshold: float = 0.5):
# # 1. 按照句号、问号、感叹号切分句子,并去除空白
# sentences = [s.strip() for s in re.split(r'[。!?]', context_info) if s.strip()]
# # 2. 统计 caption_title 中的所有唯一字符
# title_chars = set(caption_title)
# total_chars = len(title_chars)
# # 3. 计算每句中出现的 title 字符比率
# filtered = []
# for sent in sentences:
# # 计算交集字符数
# match_count = sum(1 for ch in title_chars if ch in sent)
# ratio = match_count / total_chars
# # 如果比例超过阈值,则保留
# if ratio > threshold:
# filtered.append((sent, ratio))
# return filtered
# # 示例调用
# txt = "由于他的功力浑厚,画路广阔,所以成就的方面甚多。如在意境开拓之上,他有金蝉脱壳“夺人”之法,同时,他亦服膺“画你最熟悉的”的大原则,画了不少他最亲切的农业社会的亲切图画。如《闻铃心喜图》(28-7)上画一小牧童,身系一铃铛,牵一老牛回家,画上题字云:"
# caption = "图  28-7 近现代  齐白石《闻铃心喜图》"
# results = filter_sentences_by_caption(txt, caption)
# for sent, ratio in results:
# print(f"句子: {sent}\n匹配比例: {ratio:.2f}\n")
# import re
# import logging
# # 设置日志
# logger = logging.getLogger()
# # 示例数据
# item = {'caption_text': "图 5 显示了…"} # 示例文本
# #(1) 提取图序
# image_id = ''
# if len(item['caption_text']) > 3:
# # 匹配图序:图 K、图K-1、图K-1-2 等,去掉空格
# image_id_match = re.findall(r"图\s*\d+(?:-\d+)*", item['caption_text'])
# # 去掉空格
# if image_id_match:
# image_id = image_id_match[0].replace(" ", "") # 输出: 图5 或 图5-7
# else:
# image_id = ''
# # 打印或记录日志
# logger.info(f'{item["caption_text"]} image_order is {image_id}')
# print(f"{image_id}")
# from config.config import VLM_Match_Context_User_Prompt
# from tasks.qwen_vl_infer import qwen_vl_infer
# print(VLM_Match_Context_User_Prompt.replace("{{user_text}}","wwadwadwa ").replace("{{caption}}","wwadwadwa "))
# qwen_match_response=qwen_vl_infer("https://oss.raysgo.com/oss/upload/image/png/777a14d621b3402daa2dbe7fac09d8cf.png",'你是一个图文匹配判断专家。',VLM_Match_Context_User_Prompt.replace("{{user_text}}","wwadwadwa ").replace("{{caption}}","wwadwadwa "))
# print(qwen_match_response)
import re
# 示例文本
text = """
匹配结果:不匹配
原因:文本描述为“一颗大树”,而图像中展示的是一个金色的鸟笼,内有一只小鸟,笼子上覆盖着粉色布料。文本与图像内容完全不符。
分析过程:
1. **观察图像**:图像中有一个金色的鸟笼,笼内有一只小鸟,笼子上覆盖着粉色布料。背景为浅蓝色,没有树木或其他自然元素。
2. **解析文本**:文本仅提到“一颗大树”,没有关于鸟笼、小鸟或粉色布料的描述。
3. **对比匹配**:
- **实体名称**:图像中的实体是“鸟笼”和“小鸟”,而文本中的实体是“大树”。两者名称完全不同。
- **场景环境**:图像场景为室内或简单背景,而文本描述的“大树”通常关联户外自然环境。
4. **判断标准**:根据上述对比,实体名称和场景环境均不匹配,因此判定为不匹配。
综上所述,文本与图像内容存在根本性差异,故判定为不匹配。
"""
# 修改后的正则表达式,支持中英文冒号
pattern = r'### 匹配结果[::]\s*([^\n]+)\n+### 原因[::]\s*\n([^\n]+)\n+### 分析过程:\n*'
# 使用 re.search 提取
match = re.search(pattern, text)
if match:
result = match.group(1) # 匹配结果
reason = match.group(2) # 原因
print(f"匹配结果: {result}")
print(f"原因: {reason}")
else:
print("未能匹配到内容")
print(123021)
# import re
# def normalize_latex_string(s: str) -> str:
# # 1. 去除所有空格和常见的LaTeX空白指令
# s = s.replace(" ", "")
# s = re.sub(r'\\[ ,;!]', '', s) # 移除 \, \; \! 等空白控制命令
# # 2. 标准化上下标形式
# # 将 ^{...} -> ^... 以及类似的下标
# # 如遇到更复杂的内容可根据实际情况调整
# s = re.sub(r'\^\{\s*(.*?)\s*\}', r'^\1', s)
# s = re.sub(r'\_\{\s*(.*?)\s*\}', r'_\1', s)
# # 对于简单的单字符上标下标,如 I ^ 2,若存在空格已去除,不需要额外处理
# # 如果还存在多余花括号,现在统一处理多余的花括号
# s = s.replace("{", "").replace("}", "")
# # 3. 处理特殊算子、函数
# # \cdot、\times 用 * 表示
# s = s.replace(r'\cdot', '*')
# s = s.replace(r'\times', '*')
# # 4. 处理分数 \frac{a}{b} -> (a/b)
# s = re.sub(r'\\frac\s*\((.*?)\)\s*\((.*?)\)', r'(\1/\2)', s) # 若已无花括号,可用此
# # 如果原式有花括号分隔,例如 \frac{a}{b}:
# s = re.sub(r'\\frac\s*(\S+)\s*(\S+)', r'(\1/\2)', s)
# # 若 \frac{a}{b} 中 a,b 本来就是用 {} 包围,可在上面统一处理后再替换
# # 上面已去除花括号,所以可以匹配更简单
# # 由于上面对frac的匹配可能需要更严谨的处理,可考虑:
# # 利用一种更强的regex匹配 \frac{a}{b},其中a,b为一组字符
# # 先恢复一部分逻辑(若frac的参数在花括号中)
# s = re.sub(r'\\frac\s*\((.*?)\)\s*\((.*?)\)', r'(\1/\2)', s)
# s = re.sub(r'\\frac(\w+)(\w+)', r'(\1/\2)', s)
# # 再尝试匹配剩余可能的 \frac{...}{...} 模式:
# s = re.sub(r'\\frac(\S+)\(\S+/\S+\)', lambda m: '(' + m.group(1)[1:-1] + '/' + m.group(2)[1:-1] + ')', s)
# # 由于各种复杂情况,这里已较复杂,实际使用请根据具体格式定制化。
# # 简化思路:匹配 \frac{...}{...}
# s = re.sub(r'\\frac\s*\(?([^\(\)]+)\)?\s*\(?([^\(\)]+)\)?', r'(\1/\2)', s)
# # 5. \sqrt{x} -> sqrt(x)
# s = re.sub(r'\\sqrt\s*(\S+)', r'sqrt(\1)', s)
# # 6. 去除 \left \right 这类大小调整命令
# s = s.replace(r'\left', '')
# s = s.replace(r'\right', '')
# # 7. 移除非必要的命令,如 \mathrm, \mathbf, \text, \displaystyle 等
# s = re.sub(r'\\mathrm|\\mathbf|\\text|\\displaystyle|\\normalfont|\\rm', '', s)
# # 8. 去除残余的反斜杠和空括号(如有)
# # 若无特定命令剩下的裸 \ 符号不常见,直接删除
# s = s.replace("\\", "")
# #再次清除多余空格(万一有)
# s = s.strip()
# return s
# def latex_match(A: str, B: str) -> bool:
# A_norm = normalize_latex_string(A)
# B_norm = normalize_latex_string(B)
# return A_norm in B_norm
def cal_details_path(text,info_dict):
for path_key,sub_details in info_dict.items():
if text == sub_details:
return path_key
return None
# if __name__=="__main__":
# # t=[{'type': 'paragraph', 'tags': [], 'paragraph_id': 0, 'page_id': 1, 'content': 0, 'position': [212, 51, 1987, 51, 1987, 120, 212, 120], 'outline_level': 1, 'text': 'GAOZHONG WULI GONGSHI DINGLI DINGLU TUBIAO'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 1, 'page_id': 1, 'content': 0, 'position': [74, 143, 1240, 143, 1240, 241, 74, 241], 'outline_level': -1, 'text': '高中物理公式、定理、定律图表'}, {'paragraph_id': 2, 'page_id': 1, 'content': 0, 'outline_level': -1, 'text': ' ', 'type': 'image', 'position': [51, 241, 206, 235, 212, 379, 57, 384], 'image_url': 'https://web-api.textin.com/ocr_image/external/88392321f41f4542.jpg', 'sub_type': 'image'}, {'paragraph_id': 3, 'page_id': 1, 'content': 0, 'outline_level': -1, 'text': '四、焦耳定律 ', 'type': 'image', 'position': [1286, 379, 2429, 373, 2440, 1177, 1292, 1177], 'image_url': 'https://web-api.textin.com/ocr_image/external/49492c33f0ca4aaf.jpg', 'sub_type': 'image'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 4, 'page_id': 1, 'content': 0, 'position': [287, 1275, 890, 1275, 890, 1378, 287, 1378], 'outline_level': 1, 'text': '一、知识图解'}, {'paragraph_id': 5, 'page_id': 1, 'content': 0, 'outline_level': -1, 'text': '电功: W=qU=UI 、电功率: P=IU 电功与电热 焦耳定律: Q = I ^ { 2 } R 、 热功率: P _ { 热 } = I ^ { 2 } R ', 'type': 'image', 'position': [890, 1470, 2934, 1447, 2934, 1860, 890, 1883], 'image_url': 'https://web-api.textin.com/ocr_image/external/81951d40a79f0c89.jpg', 'sub_type': 'image'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 6, 'page_id': 1, 'content': 0, 'position': [292, 1906, 1085, 1906, 1085, 2010, 292, 2010], 'outline_level': 1, 'text': '二、重要知识剖析'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 7, 'page_id': 1, 'content': 0, 'position': [246, 2113, 827, 2113, 827, 2193, 246, 2193], 'outline_level': -1, 'text': '区分电功和电热'}]
# # info_dict={'/data/wangtengbo/Deployments_Formula_Checker_v5_复现并微调_v7.0-合和OCR-没有版面/app/logs_data/sub_images/20241210152612700/formula_1.png': [{'type': 'paragraph', 'tags': [], 'paragraph_id': 0, 'page_id': 1, 'content': 0, 'position': [212, 51, 1987, 51, 1987, 120, 212, 120], 'outline_level': 1, 'text': 'GAOZHONG WULI GONGSHI DINGLI DINGLU TUBIAO'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 1, 'page_id': 1, 'content': 0, 'position': [74, 143, 1240, 143, 1240, 241, 74, 241], 'outline_level': -1, 'text': '高中物理公式、定理、定律图表'}, {'paragraph_id': 2, 'page_id': 1, 'content': 0, 'outline_level': -1, 'text': ' ', 'type': 'image', 'position': [51, 241, 206, 235, 212, 379, 57, 384], 'image_url': 'https://web-api.textin.com/ocr_image/external/88392321f41f4542.jpg', 'sub_type': 'image'}, {'paragraph_id': 3, 'page_id': 1, 'content': 0, 'outline_level': -1, 'text': '四、焦耳定律 ', 'type': 'image', 'position': [1286, 379, 2429, 373, 2440, 1177, 1292, 1177], 'image_url': 'https://web-api.textin.com/ocr_image/external/49492c33f0ca4aaf.jpg', 'sub_type': 'image'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 4, 'page_id': 1, 'content': 0, 'position': [287, 1275, 890, 1275, 890, 1378, 287, 1378], 'outline_level': 1, 'text': '一、知识图解'}, {'paragraph_id': 5, 'page_id': 1, 'content': 0, 'outline_level': -1, 'text': '电功: W=qU=UI 、电功率: P=IU 电功与电热 焦耳定律: Q = I ^ { 2 } R 、 热功率: P _ { 热 } = I ^ { 2 } R ', 'type': 'image', 'position': [890, 1470, 2934, 1447, 2934, 1860, 890, 1883], 'image_url': 'https://web-api.textin.com/ocr_image/external/81951d40a79f0c89.jpg', 'sub_type': 'image'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 6, 'page_id': 1, 'content': 0, 'position': [292, 1906, 1085, 1906, 1085, 2010, 292, 2010], 'outline_level': 1, 'text': '二、重要知识剖析'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 7, 'page_id': 1, 'content': 0, 'position': [246, 2113, 827, 2113, 827, 2193, 246, 2193], 'outline_level': -1, 'text': '区分电功和电热'}], '/data/wangtengbo/Deployments_Formula_Checker_v5_复现并微调_v7.0-合和OCR-没有版面/app/logs_data/sub_images/20241210152612700/formula_2.png': [{'type': 'paragraph', 'tags': [], 'paragraph_id': 0, 'page_id': 1, 'content': 0, 'position': [712, 7, 726, 7, 726, 0, 719, 7], 'outline_level': -1, 'text': '-'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 1, 'page_id': 1, 'content': 0, 'position': [430, 0, 712, 0, 712, 14, 430, 14], 'outline_level': -1, 'text': '---- -'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 3, 'page_id': 1, 'content': 0, 'position': [529, 7, 529, 0, 529, 14, 536, 14], 'outline_level': -1, 'text': '-'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 6, 'page_id': 1, 'content': 0, 'position': [183, 49, 3781, 49, 3781, 373, 183, 373], 'outline_level': -1, 'text': '电流做功的过程,就是把电能转化为其他形式能的过程.但有些电路元件,只将电能转化为内能,有些电路元件,电流做功将电能一部分转化为内能,还有一部分转化为其他形式的能,如机械能或化学能等.'}, {'type': 'paragraph', 'tags': ['formula'], 'paragraph_id': 7, 'page_id': 1, 'content': 0, 'position': [183, 409, 3774, 409, 3774, 754, 183, 754], 'outline_level': -1, 'text': '在纯电阻电路中:电能全部转化为内能,电功和电热相等,电功率和热功率相等.在非纯电阻电路中,电路消耗的电能W=UIt分为两部分,一大部分转化为其他形式的能;另一部分转化为内能.此时有W&gt;Q,故.此时电功只能用W=UIt计算,电热只能用$Q = I ^ { 2 } R t$计算.'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 8, 'page_id': 1, 'content': 0, 'position': [402, 846, 1192, 839, 1192, 945, 402, 952], 'outline_level': 1, 'text': '三、学习方法引导'}, {'type': 'paragraph', 'tags': ['formula', 'formula', 'formula'], 'paragraph_id': 9, 'page_id': 1, 'content': 0, 'position': [232, 1213, 994, 1213, 994, 2173, 232, 2173], 'outline_level': -1, 'text': '**名师经验谈:**①电动机是非纯电阻电路,要注意区分电功率和热功率,电功率用$P _ { 总 } = U I$ 计算,热功率$P _ { 热 } = P R$,机械功率$P _ { 机 } = P _ { 总 } - P$热·②注意搞清电动机电路中的能量转化关系.用能量守恒定律来分析问题.'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 10, 'page_id': 1, 'content': 0, 'position': [1319, 1135, 2744, 1135, 2744, 1220, 1319, 1220], 'outline_level': -1, 'text': '题型 区分纯电阻电路和非纯电阻电路'}, {'paragraph_id': 11, 'page_id': 1, 'content': 0, 'outline_level': -1, 'text': 'M A 电动机 ', 'type': 'image', 'position': [3139, 1206, 3774, 1206, 3774, 1658, 3139, 1658], 'image_url': 'https://web-api.textin.com/ocr_image/external/f72f9fe6ab6a37bb.jpg', 'sub_type': 'image'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 13, 'page_id': 1, 'content': 0, 'position': [1319, 1255, 3069, 1255, 3069, 1707, 1319, 1707], 'outline_level': -1, 'text': '**例题** 汽车电动机启动时车灯会瞬时变暗,如图所示,在打开车灯的情况下,电动机未启动时电流表读数为10A,电动机启动时电流表读数为58A、若电源电动势为12.5V,内阻'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 14, 'page_id': 1, 'content': 0, 'position': [1319, 1735, 3781, 1735, 3781, 1961, 1319, 1961], 'outline_level': -1, 'text': '为0.05Ω,电流表内阻不计,则因电动机启动,车灯的电功率降低了( )'}, {'type': 'paragraph', 'tags': [], 'paragraph_id': 15, 'page_id': 1, 'content': 0, 'position': [1319, 1982, 3690, 1982, 3690, 2067, 1319, 2067], 'outline_level': -1, 'text': 'A.35.8 W B.43.2 W C.48.2 W D.76.8 W'}, {'type': 'paragraph', 'tags': ['formula', 'formula', 'formula', 'formula', 'formula', 'formula', 'formula', 'formula', 'formula'], 'paragraph_id': 18, 'page_id': 1, 'content': 0, 'position': [1312, 2095, 3753, 2095, 3753, 3090, 1312, 3090], 'outline_level': -1, 'text': '**解析:**假设r为电源内阻,R 为车灯电阻,$I _ { 1 }$为未启动电动机时流过电流表的电流.在没启动电动机时,满足闭合电路的欧姆定律,$r + R _ { 火 }$ $k J = \\frac { E } { I _ { 1 } }$得: $R _ { 灯 } = \\frac { E } { I _ { 1 } } - r = 1 . 2 \\Omega$ 此时车灯功率为: $P _ { 灯 1 } = 1 2 0 W .$启动电动机后,流过电流表的电流$I _ { 2 } = 5 8$A,此时车灯两端电压为:$U _ { 灯 2 } = E - I _ { 2 } r = 9 . 6 V$,此时车灯功率为$P _ { 灯 2 } = \\frac { U ^ { 2 } k J 2 } { R _ { 灯 } } = \\frac { 9 . 6 ^ { 2 } } { 1 . 2 } W = 7 6 . 8 W .$电动机启动后车灯功率减少了$\\Delta P = P _ { 灯 1 } - P _ { 灯 2 } = 4 3 . 2$W 答案 **B**'}, {'paragraph_id': 19, 'page_id': 1, 'tags': [], 'outline_level': -1, 'text': '66', 'type': 'paragraph', 'position': [84, 3443, 190, 3443, 190, 3527, 84, 3527], 'content': 1, 'sub_type': 'footer'}]}
# # print(cal_details_path(t,info_dict))
# text_A="""$\\frac { \\sqrt { a } } { \\sqrt { b } } = \\sqrt { \\frac { a } { b } } ( a > 0 )$·"""
# text_B=r"\frac { \sqrt { a } } { \sqrt { b } } = \sqrt { \frac { a } { b } } ( a > 0 )"
# print(latex_match(text_B,text_A)) #False
import re
def normalize_latex_string(s: str) -> str:
# 1. 去除空格和数学环境符号
s = s.replace(" ", "").replace("$", "")
s = re.sub(r'\\[ ,;!]', '', s) # 移除 \, \; \! 等
# 2. 处理上标和下标
s = re.sub(r'\^\{\s*(.*?)\s*\}', r'^\1', s)
s = re.sub(r'\_\{\s*(.*?)\s*\}', r'_\1', s)
s = s.replace("{", "").replace("}", "") # 去除剩余花括号
# 3. 替换算子和特殊符号
s = s.replace(r'\cdot', '*').replace(r'\times', '*')
# 4. 处理分数
s = re.sub(r'\\frac\s*\{([^{}]+)\}\s*\{([^{}]+)\}', r'(\1/\2)', s)
# 5. 处理平方根
s = re.sub(r'\\sqrt\s*\{([^{}]+)\}', r'sqrt(\1)', s)
# 6. 移除 \left 和 \right
s = s.replace(r'\left', '').replace(r'\right', '')
# 7. 移除格式化命令
s = re.sub(r'\\mathrm|\\mathbf|\\text|\\displaystyle|\\normalfont|\\rm', '', s)
# 8. 移除残余反斜杠
s = s.replace("\\", "")
# 去除多余空格
return s.strip()
def latex_match(A: str, B: str) -> bool:
A_norm = normalize_latex_string(A)
B_norm = normalize_latex_string(B)
#print("Normalized A:", A_norm)
#print("Normalized B:", B_norm)
return A_norm in B_norm
if __name__ == "__main__":
text_B = r"\frac { \sqrt { 8 } } { \sqrt { 2 a } } = \frac { \sqrt { 8 } \cdot \sqrt { 2 a } } { \sqrt { 2 a } \cdot \sqrt { 2 a } } = \frac { 4 \sqrt { a } } { 2 a } = \frac { 2 \sqrt { a } } { a }"
text_A = r"'(3)$\\frac { \\sqrt { 8 } } { \\sqrt { 2 a } } = \\frac { \\sqrt { 8 } \\cdot \\sqrt { 2 a } } { \\sqrt { 2 a } \\cdot \\sqrt { 2 a } } = \\frac { 4 \\sqrt { a } } { 2 a } = \\frac { 2 \\sqrt { a } } { a } .$"
print(latex_match(text_B, text_A)) # 应该返回 True
import re
# 标准化策略1:基础标准化
def normalize_strategy_1(s: str) -> str:
s = s.replace(" ", "").replace("$", "")
s = re.sub(r'\\[ ,;!]', '', s) # 移除 \, \; \! 等
s = re.sub(r'\^\{\s*(.*?)\s*\}', r'^\1', s)
s = re.sub(r'\_\{\s*(.*?)\s*\}', r'_\1', s)
s = s.replace("{", "").replace("}", "")
s = s.replace(r'\cdot', '*').replace(r'\times', '*')
s = re.sub(r'\\frac\s*\{([^{}]+)\}\s*\{([^{}]+)\}', r'(\1/\2)', s)
s = re.sub(r'\\sqrt\s*\{([^{}]+)\}', r'sqrt(\1)', s)
s = s.replace(r'\left', '').replace(r'\right', '')
s = re.sub(r'\\mathrm|\\mathbf|\\text|\\displaystyle|\\normalfont|\\rm', '', s)
s = s.replace("\\", "")
return s.strip()
# 标准化策略2:更宽松的标准化
def normalize_strategy_2(s: str) -> str:
s = re.sub(r'\\[a-zA-Z]+', '', s) # 去掉所有 LaTeX 命令
s = re.sub(r'[^a-zA-Z0-9^_*/().]', '', s) # 保留基本符号
return s.strip()
# 标准化策略3:保留部分数学符号
def normalize_strategy_3(s: str) -> str:
s = s.replace(" ", "").replace("$", "")
s = re.sub(r'\\frac\s*\{([^{}]+)\}\s*\{([^{}]+)\}', r'(\1/\2)', s)
s = re.sub(r'\\sqrt\s*\{([^{}]+)\}', r'sqrt(\1)', s)
s = s.replace("\\", "")
return s.strip()
# 标准化策略4:完全删除所有 LaTeX 特殊符号,仅保留文字和数字
def normalize_strategy_4(s: str) -> str:
s = re.sub(r'\\[a-zA-Z]+', '', s) # 去掉所有 LaTeX 命令
s = re.sub(r'[{}^_\\]', '', s) # 去掉特殊符号
s = s.replace(" ", "").replace("$", "")
return s.strip()
# 主函数
def normalize_latex_string(s: str, strategy: int) -> str:
if strategy == 1:
return normalize_strategy_1(s)
elif strategy == 2:
return normalize_strategy_2(s)
elif strategy == 3:
return normalize_strategy_3(s)
elif strategy == 4:
return normalize_strategy_4(s)
else:
raise ValueError("Invalid normalization strategy")
# 匹配函数
def latex_match(A: str, B: str) -> bool:
for strategy in range(1, 5): # 尝试所有标准化策略
A_norm = normalize_latex_string(A, strategy)
B_norm = normalize_latex_string(B, strategy)
if A_norm in B_norm:
return True
return False
# 测试代码
if __name__ == "__main__":
text_B = r"\frac { \sqrt { 8 } } { \sqrt { 2 a } } = \frac { \sqrt { 8 } \cdot \sqrt { 2 a } } { \sqrt { 2 a } \cdot \sqrt { 2 a } } = \frac { 4 \sqrt { a } } { 2 a } = \frac { 2 \sqrt { a } } { a }"
text_A = r"'(3)$\\frac { \\sqrt { 8 } } { \\sqrt { 2 a } } = \\frac { \\sqrt { 8 } \\cdot \\sqrt { 2 a } } { \\sqrt { 2 a } \\cdot \\sqrt { 2 a } } = \\frac { 4 \\sqrt { a } } { 2 a } = \\frac { 2 \\sqrt { a } } { a } .$"
print(latex_match(text_B, text_A)) # 应该返回 True
import re
def normalize_latex_string(s: str) -> str:
"""标准化 LaTeX 字符串"""
s = s.replace(" ", "").replace("$", "")
s = re.sub(r'\\[ ,;!]', '', s) # 移除空白控制符
s = re.sub(r'\^\{\s*(.*?)\s*\}', r'^\1', s) # 处理上标
s = re.sub(r'\_\{\s*(.*?)\s*\}', r'_\1', s) # 处理下标
s = s.replace("{", "").replace("}", "") # 移除花括号
s = s.replace(r'\cdot', '*').replace(r'\times', '*') # 替换乘号
s = re.sub(r'\\frac\s*\{([^{}]+)\}\s*\{([^{}]+)\}', r'(\1/\2)', s) # 处理分数
s = re.sub(r'\\sqrt\s*\{([^{}]+)\}', r'sqrt(\1)', s) # 处理平方根
s = s.replace(r'\left', '').replace(r'\right', '') # 移除 \left 和 \right
s = re.sub(r'\\mathrm|\\mathbf|\\text|\\displaystyle|\\normalfont|\\rm', '', s) # 移除样式命令
s = s.replace("\\", "") # 去除残余的反斜杠
s = s.replace("≥", ">=").replace("≤", "<=") # 标准化符号
return s.strip()
def extract_latex(text: str) -> list:
"""提取文本中的所有 LaTeX 公式"""
matches = re.findall(r'\$(.*?)\$', text)
return matches
def latex_match(A: str, B: str) -> bool:
"""检查 A 是否匹配 B 中的任意公式"""
# 提取并标准化 B 中的所有公式
B_formulas = extract_latex(B)
B_norm_formulas = [normalize_latex_string(f) for f in B_formulas]
# 标准化 A
A_norm = normalize_latex_string(A)
# 检查 A 是否匹配任意标准化的公式
for formula in B_norm_formulas:
if A_norm == formula:
return True
return False
if __name__ == "__main__":
text_B = r"\frac { \sqrt { 8 } } { \sqrt { 2 a } } = \frac { \sqrt { 8 } \cdot \sqrt { 2 a } } { \sqrt { 2 a } \cdot \sqrt { 2 a } } = \frac { 4 \sqrt { a } } { 2 a } = \frac { 2 \sqrt { a } } { a }"
text_A = r"'(3)$\\frac { \\sqrt { 8 } } { \\sqrt { 2 a } } = \\frac { \\sqrt { 8 } \\cdot \\sqrt { 2 a } } { \\sqrt { 2 a } \\cdot \\sqrt { 2 a } } = \\frac { 4 \\sqrt { a } } { 2 a } = \\frac { 2 \\sqrt { a } } { a } .$"
print(latex_match(text_B, text_A)) # 应该返回 True
import os
from datetime import datetime
import langid
import json
import base64
import re
def save_logs_to_file(logs_info, file_name=None):
if file_name==None:
return False
with open(file_name, 'w', encoding='utf-8') as file:
json.dump(logs_info, file, ensure_ascii=False, indent=4)
return True
def clean_markdown(markdown_text):
"""
清理Markdown文本,移除所有图像链接、HTML注释及其他无用信息,仅保留正文内容。
参数:
markdown_text (str): 输入的Markdown格式文本。
返回:
str: 清理后的正文内容。
"""
# 删除HTML <img> 标签
markdown_text = re.sub(r'<img[^>]*>', '', markdown_text)
# 删除Markdown ![]() 语法的图片
markdown_text = re.sub(r'!\[[^\]]*\]\([^)]*\)', '', markdown_text)
# 删除HTML注释 <!-- ... -->
markdown_text = re.sub(r'<!--.*?-->', '', markdown_text, flags=re.DOTALL)
# # 删除多余的空行(可选)
# markdown_text = re.sub(r'\n\s*\n', '\n\n', markdown_text)
# 去除行尾多余的空白字符
markdown_text = re.sub(r'[ \t]+$', '', markdown_text, flags=re.MULTILINE)
return markdown_text.strip()
import re
def contains_formula(markdown_info):
"""
全面覆盖公式模式的函数:判断 Markdown 文本中是否包含公式。
参数:
markdown_info (str): Markdown 文本。
返回:
bool: 如果包含公式返回 True,否则返回 False。
"""
# 全面公式检测正则模式
formula_patterns = [
r'\$.*?\$', # 行内公式,例如 $a = b + c$
r'\$\$.*?\$\$', # 块级公式,例如 $$a = b + c$$
r'\\\(.+?\\\)', # 行内公式,例如 \(a = b + c\)
r'\\\[.+?\\\]', # 块级公式,例如 \[a = b + c\]
r'\\frac\s*{.*?}\s*{.*?}', # LaTeX 分式,例如 \frac{a}{b}
r'[a-zA-Z]\s*_\s*{.*?}', # 下标公式,例如 a_{n}
r'[a-zA-Z]\s*\^\s*{.*?}', # 上标公式,例如 a^{2}
r'[a-zA-Z]\s*=\s*.+', # 常见赋值公式,例如 a = b + c
r'\\int\s*.*', # 积分公式,例如 \int_{a}^{b} x^2 dx
r'\\sum\s*.*', # 求和公式,例如 \sum_{i=1}^n i
r'\\prod\s*.*', # 求积公式,例如 \prod_{i=1}^n i
r'\\lim\s*.*', # 极限公式,例如 \lim_{x \to 0} f(x)
r'\\sqrt\s*{.*?}', # 平方根,例如 \sqrt{x}
r'\\log\s*.*', # 对数公式,例如 \log_{2}x
r'\\sin|\\cos|\\tan|\\arcsin|\\arccos|\\arctan', # 三角函数
r'\\left.*?\\right', # 带括号的公式,例如 \left( x+y \right)
r'\\mathrm\s*{.*?}', # LaTeX 数学字体控制,例如 \mathrm{sin}
r'\\overline\s*{.*?}', # 上划线,例如 \overline{AB}
r'\\underline\s*{.*?}', # 下划线,例如 \underline{AB}
r'\\cdot|\\times|\\div|\\pm', # 运算符,例如 \cdot, \times, \div, \pm
r'[a-zA-Z]\s*\\to\s*.*', # 函数关系,例如 x \to y
r'[Α-Ωα-ω]', # 希腊字母,例如 Φ, θ
r'[a-zA-Z]+\s*=\s*[a-zA-Z0-9]+\s*(cos|sin|tan|log|exp|sqrt)\s*[a-zA-Z0-9θ]*', # 带函数符号的公式
]
# 组合所有正则进行检测
for pattern in formula_patterns:
if re.search(pattern, markdown_info, re.DOTALL):
return True
return False
# def contains_formula(markdown_info):
# """
# 全面覆盖公式模式的函数:判断 Markdown 文本中是否包含公式。
# 参数:
# markdown_info (str): Markdown 文本。
# 返回:
# bool: 如果包含公式返回 True,否则返回 False。
# """
# # 全面公式检测正则模式
# formula_patterns = [
# r'\$.*?\$', # 行内公式,例如 $a = b + c$
# r'\$\$.*?\$\$', # 块级公式,例如 $$a = b + c$$
# r'\\\(.+?\\\)', # 行内公式,例如 \(a = b + c\)
# r'\\\[.+?\\\]', # 块级公式,例如 \[a = b + c\]
# r'\\frac\s*{.*?}\s*{.*?}', # LaTeX 分式,例如 \frac{a}{b}
# r'[a-zA-Z]\s*_\s*{.*?}', # 下标公式,例如 a_{n}
# r'[a-zA-Z]\s*\^\s*{.*?}', # 上标公式,例如 a^{2}
# r'[a-zA-Z]\s*=\s*.+', # 常见赋值公式,例如 a = b + c
# r'\\int\s*.*', # 积分公式,例如 \int_{a}^{b} x^2 dx
# r'\\sum\s*.*', # 求和公式,例如 \sum_{i=1}^n i
# r'\\prod\s*.*', # 求积公式,例如 \prod_{i=1}^n i
# r'\\lim\s*.*', # 极限公式,例如 \lim_{x \to 0} f(x)
# r'\\sqrt\s*{.*?}', # 平方根,例如 \sqrt{x}
# r'\\log\s*.*', # 对数公式,例如 \log_{2}x
# r'\\sin|\\cos|\\tan|\\arcsin|\\arccos|\\arctan', # 三角函数
# r'\\left.*?\\right', # 带括号的公式,例如 \left( x+y \right)
# r'\\mathrm\s*{.*?}', # LaTeX 数学字体控制,例如 \mathrm{sin}
# r'\\overline\s*{.*?}', # 上划线,例如 \overline{AB}
# r'\\underline\s*{.*?}', # 下划线,例如 \underline{AB}
# r'\\cdot|\\times|\\div|\\pm', # 运算符,例如 \cdot, \times, \div, \pm
# r'[a-zA-Z]\s*\\to\s*.*', # 函数关系,例如 x \to y
# ]
# # 组合所有正则进行检测
# for pattern in formula_patterns:
# if re.search(pattern, markdown_info, re.DOTALL):
# return True
# return False
def mkdir_if_not_exist(path):
if not os.path.exists(path):
os.makedirs(path)
# 将日志信息保存为 JSON 文件
def save_logs_to_file(logs_info, file_name=None):
if file_name==None:
return False
with open(file_name, 'w', encoding='utf-8') as file:
json.dump(logs_info, file, ensure_ascii=False, indent=4)
return True
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
# 获取毫秒级时间
def get_millisecond_time():
current_time = datetime.now()
time_str = current_time.strftime("%Y%m%d%H%M%S%f")[:-3]
return time_str
def get_lang(text):
lang_detect, _ = langid.classify(text.replace('。', ' ').replace(',', ' ')) # 语言检测
return 'en' if lang_detect == 'en' else 'zh'
def get_day_time():
# 获取当前日期和时间
now = datetime.now()
# 格式化日期和时间为字符串,格式为 "YYYYMMDD_HHMMSS"
formatted_time = now.strftime("%Y%m%d_%H%M%S")
return formatted_time
def merge_ver_boxes(formula_positions):
def sort_by_y_min(box):
return box[2][1]
formula_positions.sort(key=sort_by_y_min)
merged_boxes = []
current_box = None
for box in formula_positions:
category, confidence, bbox = box
x_min, y_min, x_max, y_max = bbox
if current_box is None:
current_box = [category, confidence, bbox]
else:
curr_x_min, curr_y_min, curr_x_max, curr_y_max = current_box[2]
if y_min <= curr_y_max:
# Merge the boxes
merged_bbox = [
min(curr_x_min, x_min),
min(curr_y_min, y_min),
max(curr_x_max, x_max),
max(curr_y_max, y_max)
]
current_box[2] = merged_bbox
else:
# Append the current box to merged_boxes list and start a new current_box
merged_boxes.append(current_box)
current_box = [category, confidence, bbox]
if current_box is not None:
merged_boxes.append(current_box)
return merged_boxes
def is_break(text):
# 定义关键词列表
keywords = [
'答','选择','填空','选择题','填空题','分析','求解','计算','化简','证明', '解', '例题', '名师经验谈', '解析', '答案','练习', 'A.','B.','C.','D.','(1)','(2)','(3)','(4)','(5)','(6)','(7)',
]
# 检查关键词是否存在于文本中
for keyword in keywords:
if keyword in text:
return False # 如果发现关键词,返回 False
return True # 如果所有关键词都不存在,返回 True
import re
# def check_numbers_in_string(text):
# """
# 检查字符串中的数字,如果包含大于10的数字返回False,否则返回True。
# :param text: 输入的字符串
# :return: 布尔值
# """
# # 提取字符串中的所有数字
# numbers = re.findall(r'\d+', text)
# print(numbers)
# # 转换为整数类型
# int_numbers = [int(num) for num in numbers]
# # 检查是否存在大于10的数字
# for num in int_numbers:
# if num >= 5:
# return False
# return True
def check_numbers_in_string(text):
"""
检查字符串中的数字,如果包含大于或等于5的数字返回False,否则返回True。
:param text: 输入的字符串
:return: 布尔值
"""
# 提取字符串中的所有数字,包括中间可能有空格的情况
numbers = re.findall(r'\d(?:\s*\d)*', text) # 匹配数字,允许中间有空格
#print("提取的原始数字:", numbers)
# 去除空格并转换为整数类型
int_numbers = [int(num.replace(' ', '')) for num in numbers]
#print("转换为整数后的数字:", int_numbers)
# 检查是否存在大于或等于5的数字
for num in int_numbers:
if num >= 5:
return False
return True
def process_text(text):
"""
去除文本中的所有空行,并按照换行符进行分割。
:param text: 输入的多行字符串
:return: 去除空行后按 \n 分割的列表
"""
# 去除空行
non_empty_lines = [line for line in text.splitlines() if line.strip()]
return non_empty_lines
def Scan_Content_Aggregation(all_page_details):
all_results_markdown=[]
all_lines_info=process_text(all_page_details)
#all_page_details=[[{},{}],[{},{}]]
#过滤内容
print(all_lines_info)
for line in all_lines_info:
if not is_break(line):
continue
if not check_numbers_in_string(line):
continue
all_results_markdown.append(line)
return '\n'.join(all_results_markdown)
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
if __name__=="__main__":
original_text = "\\frac { \\sqrt { a } } { \\sqrt { b } } = \\sqrt { \\frac { a } { b } } ( a \\geq 0 , b > 0 )"
cleaned_text = check_numbers_in_string(original_text)
print(repr(cleaned_text))
import re
import json
from loguru import logger
def calculate_iou(inner_box, outer_box):
# 提取内部边界框的四个顶点
x1, y1 = inner_box[0]
x2, y2 = inner_box[1]
x3, y3 = inner_box[2]
x4, y4 = inner_box[3]
# 计算内部边界框的最小和最大坐标
x_min_inner = min(x1, x2, x3, x4)
y_min_inner = min(y1, y2, y3, y4)
x_max_inner = max(x1, x2, x3, x4)
y_max_inner = max(y1, y2, y3, y4)
# 提取外部边界框的坐标
x_min_outer, y_min_outer, x_max_outer, y_max_outer = outer_box
# 计算交集的坐标
x_min_inter = max(x_min_inner, x_min_outer)
y_min_inter = max(y_min_inner, y_min_outer)
x_max_inter = min(x_max_inner, x_max_outer)
y_max_inter = min(y_max_inner, y_max_outer)
# 计算交集的宽度和高度
inter_width = max(0, x_max_inter - x_min_inter)
inter_height = max(0, y_max_inter - y_min_inter)
# 计算交集面积
inter_area = inter_width * inter_height
# 计算两个边界框的面积
inner_area = (x_max_inner - x_min_inner) * (y_max_inner - y_min_inner)
outer_area = (x_max_outer - x_min_outer) * (y_max_outer - y_min_outer)
# 计算并集面积
union_area = inner_area + outer_area - inter_area
# 计算IoU
iou = inter_area / union_area if union_area != 0 else 0
return iou
def has_intersection(inner_box, outer_box, threshold=0.1):
"""
判断inner_box是否与outer_box有交集(IOU > 0.1)。
参数:
inner_box (list): 内部边界框,格式为[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
outer_box (list): 外部边界框,格式为[x_min, y_min, x_max, y_max]
threshold (float): 判断交集的IOU阈值,默认为0.1
返回:
bool: 如果inner_box与outer_box有交集(IOU > threshold)则返回True,否则返回False。
"""
iou = calculate_iou(inner_box, outer_box)
return iou > 0
# def perform_re_check(text):
# """
# 检查文本中是否包含公式。
# 参数:
# text (str): 要检查的文本
# 返回:
# bool: 如果包含公式则返回True,否则返回False。
# """
# formula_pattern = re.compile(
# r"([A-Za-z]+\s*=\s*[A-Za-z0-9+\-*/^()]+)|" # 一般公式,包含多个运算符或字母
# r"(\b√[A-Za-z0-9]+\b)|" # 根号
# r"(\bΔ\b)|" # Δ
# r"(\([A-Za-z0-9+\-*/^()]+\)\s*[+\-*/^]\s*\([A-Za-z0-9+\-*/^()]+\))|" # 复杂括号表达式
# r"([A-Za-z]*\d*[(x)(y)(z)(a)(b)(c)]*\s*[+\-*/^]+\s*\(?[A-Za-z0-9+\-*/^()]+\)?)" # 公式中的符号运算
# )
# return bool(formula_pattern.search(text))
def perform_re_check(text):
# 规则1:过滤掉包含仅字母、空格或中文字符的文本
if re.fullmatch(r'[\u4e00-\u9fa5a-zA-Z\s]+', text):
return False
# 规则2:过滤掉不带“=”符号的text
if '=' not in text:
return False
# 规则3:过滤掉表达式和赋值,例如a=5
if re.fullmatch(r'[a-zA-Z]+\s*=\s*\d+', text):
return False
return True
def filter_boxes(ocr_result, layout_result):
"""
过滤排版检测结果,保留下包含公式的边界框。
参数:
ocr_result (dict): OCR识别结果
layout_result (dict): 布局检测结果
返回:
list: 过滤后的排版检测结果
"""
logger.info('in filter_boxes!')
#layout_data = json.loads(json.loads(layout_result['data'])['Ids_Scores_boxes'])
layout_data=layout_result
filtered_layout_boxes = []
filtered_layout_ocrs=[]
for layout_box in layout_data:
layout_coordinates = layout_box[2]
combined_text = ""
for ocr_box in ocr_result['data']:
ocr_coordinates = ocr_box[0]
ocr_text = ocr_box[1][0]
if has_intersection(ocr_coordinates, layout_coordinates):
combined_text += ocr_text + " "
if perform_re_check(combined_text.strip()):
filtered_layout_boxes.append(layout_box)
filtered_layout_ocrs.append(combined_text)
return filtered_layout_boxes,filtered_layout_ocrs
# # 示例使用
# ocr_result = {'errorCode': 0, 'msg': '识别成功', 'data': [[[[96.0, 44.0], [569.0, 44.0], [569.0, 64.0], [96.0, 64.0]], ['1.4.1用空间向量研究直线、平面的位置关系(第3课时)', 0.9963806867599487]], [[[238.0, 97.0], [426.0, 97.0], [426.0, 116.0], [238.0, 116.0]], ['空间中直线、平面的垂直', 0.9977184534072876]], [[[73.0, 145.0], [180.0, 145.0], [180.0, 170.0], [73.0, 170.0]], ['知识清单', 0.998611330986023]], [[[105.0, 187.0], [231.0, 187.0], [231.0, 204.0], [105.0, 204.0]], ['1.直线和直线垂直', 0.9975678324699402]], [[[103.0, 213.0], [544.0, 214.0], [544.0, 234.0], [103.0, 233.0]], ['设直线,l的方向向量分别为u,μ,则uuuu=3.', 0.8799490332603455]], [[[104.0, 243.0], [231.0, 243.0], [231.0, 260.0], [104.0, 260.0]], ['2.直线和平面垂直', 0.9996089935302734]], [[[105.0, 269.0], [591.0, 269.0], [591.0, 286.0], [105.0, 286.0]], ['设u是直线l的方向向量,n是平面α的法向量,lα,则lLαu/nA', 0.8951645493507385]], [[[71.0, 296.0], [177.0, 298.0], [177.0, 316.0], [71.0, 314.0]], ['eR,使得ux入n.', 0.871402382850647]], [[[105.0, 325.0], [267.0, 325.0], [267.0, 341.0], [105.0, 341.0]], ['要点3平面和平面垂直', 0.9992749094963074]], [[[103.0, 349.0], [530.0, 351.0], [530.0, 371.0], [103.0, 369.0]], ['设n,n分别是平面α,β的法向量,则α⊥βn,⊥nn*n=1.', 0.8930713534355164]], [[[73.0, 397.0], [179.0, 397.0], [179.0, 421.0], [73.0, 421.0]], ['例题讲评', 0.9989429712295532]], [[[113.0, 441.0], [592.0, 441.0], [592.0, 461.0], [113.0, 461.0]], ['例1如图,已知正三棱柱ABC-A,B,C,的各棱长都为1,M是底面上', 0.9656786918640137]], [[[72.0, 474.0], [309.0, 474.0], [309.0, 490.0], [72.0, 490.0]], ['BC边的中点,V是侧棱CC上的点,', 0.991436243057251]], [[[102.0, 508.0], [162.0, 504.0], [163.0, 525.0], [104.0, 528.0]], ['且CV=', 0.9438493251800537]], [[[175.0, 507.0], [203.0, 510.0], [201.0, 526.0], [174.0, 524.0]], ['CC,', 0.7943589091300964]], [[[161.0, 522.0], [174.0, 522.0], [174.0, 531.0], [161.0, 531.0]], ['4', 0.9967143535614014]], [[[104.0, 544.0], [234.0, 544.0], [234.0, 560.0], [104.0, 560.0]], ['(1)求证:ABLMN;', 0.8688576817512512]], [[[103.0, 568.0], [343.0, 569.0], [343.0, 589.0], [103.0, 588.0]], ['(2)设CC,中点为D,求证:AB,⊥A,D.', 0.8998482823371887]], [[[106.0, 736.0], [589.0, 738.0], [589.0, 758.0], [106.0, 756.0]], ['练习1如图,△ABC和△BCD所在平面互相垂直,且AB=BC=BD=', 0.9863734245300293]], [[[72.0, 773.0], [519.0, 773.0], [519.0, 789.0], [72.0, 789.0]], ['2,ZABC=ZDBC=120°,E,F分别为AC,DC的中点.求证:EF⊥BC.', 0.9168810844421387]]]}
# layout_result = {'errorCode': 0, 'msg': '识别成功', 'data': '{"Ids_Scores_boxes": "[[[0], 0.5164221525192261, [72.51856621809566, 187.46995484912938, 593.7710048745755, 369.08707769273684]], [[10], 0.4314861297607422, [74.0, 741.0, 589.0, 788.0]], [[1], 0.8855363726615906, [74.0858920888629, 143.8997421042558, 179.57114012112666, 169.70436416413207]], [[1], 0.8849098086357117, [73.6531480480851, 396.0868356459496, 178.90487691988264, 420.96475408517415]], [[1], 0.8246437907218933, [94.7679979762152, 44.78538083893527, 567.6477659821933, 64.29004998079591]], [[1], 0.0, [241.0, 100.0, 425.0, 114.0]], [[0], 0.0, [74.0, 443.0, 592.0, 587.0]]]", "boxes_num": "7"}'}
# filtered_layout_boxes = filter_boxes(ocr_result, layout_result)
# print(filtered_layout_boxes)
import re
import json
def calculate_iou(inner_box, outer_box):
# 提取内部边界框的四个顶点
x1, y1 = inner_box[0]
x2, y2 = inner_box[1]
x3, y3 = inner_box[2]
x4, y4 = inner_box[3]
# 计算内部边界框的最小和最大坐标
x_min_inner = min(x1, x2, x3, x4)
y_min_inner = min(y1, y2, y3, y4)
x_max_inner = max(x1, x2, x3, x4)
y_max_inner = max(y1, y2, y3, y4)
# 提取外部边界框的坐标
x_min_outer, y_min_outer, x_max_outer, y_max_outer = outer_box
# 计算交集的坐标
x_min_inter = max(x_min_inner, x_min_outer)
y_min_inter = max(y_min_inner, y_min_outer)
x_max_inter = min(x_max_inner, x_max_outer)
y_max_inter = min(y_max_inner, y_max_outer)
# 计算交集的宽度和高度
inter_width = max(0, x_max_inter - x_min_inter)
inter_height = max(0, y_max_inter - y_min_inter)
# 计算交集面积
inter_area = inter_width * inter_height
# 计算两个边界框的面积
inner_area = (x_max_inner - x_min_inner) * (y_max_inner - y_min_inner)
outer_area = (x_max_outer - x_min_outer) * (y_max_outer - y_min_outer)
# 计算并集面积
union_area = inner_area + outer_area - inter_area
# 计算IoU
iou = inter_area / union_area if union_area != 0 else 0
return iou
def has_intersection(inner_box, outer_box, threshold=0.1):
"""
判断inner_box是否与outer_box有交集(IOU > 0.1)。
参数:
inner_box (list): 内部边界框,格式为[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
outer_box (list): 外部边界框,格式为[x_min, y_min, x_max, y_max]
threshold (float): 判断交集的IOU阈值,默认为0.1
返回:
bool: 如果inner_box与outer_box有交集(IOU > threshold)则返回True,否则返回False。
"""
iou = calculate_iou(inner_box, outer_box)
return iou > 0
# def perform_re_check(text):
# """
# 检查文本中是否包含公式。
# 参数:
# text (str): 要检查的文本
# 返回:
# bool: 如果包含公式则返回True,否则返回False。
# """
# formula_pattern = re.compile(
# r"([A-Za-z]+\s*=\s*[A-Za-z0-9+\-*/^()]+)|" # 一般公式,包含多个运算符或字母
# r"(\b√[A-Za-z0-9]+\b)|" # 根号
# r"(\bΔ\b)|" # Δ
# r"(\([A-Za-z0-9+\-*/^()]+\)\s*[+\-*/^]\s*\([A-Za-z0-9+\-*/^()]+\))|" # 复杂括号表达式
# r"([A-Za-z]*\d*[(x)(y)(z)(a)(b)(c)]*\s*[+\-*/^]+\s*\(?[A-Za-z0-9+\-*/^()]+\)?)" # 公式中的符号运算
# )
# return bool(formula_pattern.search(text))
def perform_re_check(text):
# 规则1:过滤掉包含仅字母、空格或中文字符的文本
if re.fullmatch(r'[\u4e00-\u9fa5a-zA-Z\s]+', text):
return False
# 规则2:过滤掉不带“=”符号的text
if '=' not in text:
return False
# 规则3:过滤掉表达式和赋值,例如a=5
if re.fullmatch(r'[a-zA-Z]+\s*=\s*\d+', text):
return False
return True
def filter_boxes(ocr_result, layout_result):
"""
过滤排版检测结果,保留下包含公式的边界框。
参数:
ocr_result (dict): OCR识别结果
layout_result (dict): 布局检测结果
返回:
list: 过滤后的排版检测结果
"""
layout_data = json.loads(json.loads(layout_result['data'])['Ids_Scores_boxes'])
filtered_layout_boxes = []
filtered_layout_ocrs=[]
for layout_box in layout_data:
layout_coordinates = layout_box[2]
combined_text = ""
for ocr_box in ocr_result['data']:
ocr_coordinates = ocr_box[0]
ocr_text = ocr_box[1][0]
if has_intersection(ocr_coordinates, layout_coordinates):
combined_text += ocr_text + " "
#print(combined_text)
if perform_re_check(combined_text.strip()):
filtered_layout_boxes.append(layout_box)
filtered_layout_ocrs.append(combined_text)
return filtered_layout_boxes,filtered_layout_ocrs
# # # # 示例使用
# ocr_result = {'errorCode': 0, 'msg': '识别成功', 'data': [[[[132.0, 6.0], [487.0, 6.0], [487.0, 23.0], [132.0, 23.0]], ['1.4.2用空间向量研究距离、夹角问题(一)', 0.9909250140190125]], [[[274.0, 57.0], [348.0, 57.0], [348.0, 77.0], [274.0, 77.0]], ['空间距离', 0.9974817037582397]], [[[50.0, 105.0], [158.0, 105.0], [158.0, 130.0], [50.0, 130.0]], ['知识清单', 0.9985876083374023]], [[[83.0, 148.0], [209.0, 148.0], [209.0, 165.0], [83.0, 165.0]], ['1.点到直线的距离', 0.999904215335846]], [[[84.0, 176.0], [395.0, 176.0], [395.0, 192.0], [84.0, 192.0]], ["已知直线l的方向向量是a,点P#l,P'el,则点", 0.8676444292068481]], [[[83.0, 223.0], [236.0, 223.0], [236.0, 240.0], [83.0, 240.0]], ['P到直线l的距离为d:', 0.932934045791626]], [[[270.0, 227.0], [289.0, 227.0], [289.0, 234.0], [270.0, 234.0]], ['DP', 0.7328389883041382]], [[[83.0, 269.0], [438.0, 269.0], [438.0, 286.0], [83.0, 286.0]], ['两条平行直线间的距离可以转化为点到直线的距离,', 0.9892963767051697]], [[[82.0, 297.0], [209.0, 297.0], [209.0, 313.0], [82.0, 313.0]], ['2.点到平面的距离', 0.9998923540115356]], [[[84.0, 323.0], [567.0, 323.0], [567.0, 340.0], [84.0, 340.0]], ['已知AB为平面α的一条斜线段(点A在平面α内),n为平面α的法向量,', 0.9887251853942871]], [[[47.0, 358.0], [392.0, 360.0], [392.0, 387.0], [47.0, 385.0]], ['则点B到平面α的距离为d=AB|·cos<AB,n)', 0.9637517929077148]], [[[378.0, 351.0], [438.0, 351.0], [438.0, 372.0], [378.0, 372.0]], ['無·n', 0.6339988112449646]], [[[447.0, 365.0], [570.0, 365.0], [570.0, 382.0], [447.0, 382.0]], ['空间中其他距离', 0.9999088644981384]], [[[49.0, 408.0], [85.0, 408.0], [85.0, 427.0], [49.0, 427.0]], ['问题', 0.9955520629882812]], [[[86.0, 435.0], [296.0, 435.0], [296.0, 452.0], [86.0, 452.0]], ['一般都可以转化为点面距问题.', 0.9974074363708496]], [[[49.0, 479.0], [159.0, 479.0], [159.0, 507.0], [49.0, 507.0]], ['例题讲评', 0.9977942109107971]], [[[91.0, 525.0], [571.0, 525.0], [571.0, 545.0], [91.0, 545.0]], ['例1如图,在棱长为2的正方体ABCD-A,B,C,D,中,E是BC的中点,P', 0.9382723569869995]], [[[52.0, 558.0], [286.0, 558.0], [286.0, 575.0], [52.0, 575.0]], ['是AE上的动点,求DP的最小值', 0.9763669371604919]], [[[84.0, 725.0], [570.0, 727.0], [570.0, 748.0], [84.0, 746.0]], ['练习1在长方体0ABC-0,A,B,C,中,0A=2,AB=3,AA,=2,求0,到', 0.887532114982605]], [[[50.0, 759.0], [159.0, 759.0], [159.0, 779.0], [50.0, 779.0]], ['直线AC的距离.', 0.9977055788040161]]]}
# layout_result = {'errorCode': 0, 'msg': '识别成功', 'data': '{"Ids_Scores_boxes": "[[[1], 0.6768088340759277, [132.82876458235634, 4.867646808112585, 484.25709724895194, 24.110981528087148]], [[1], 0.5721949338912964, [50.41460480158925, 478.7880785462551, 157.3041640894006, 504.33530555645126]], [[1], 0.699893593788147, [53.0, 526.0, 570.0, 576.0]], [[10, 0], 0.6270818710327148, [52.00979196153577, 761.0801100416344, 155.13669618897046, 776.1116098266145]], [[10], 0.0, [86.0, 727.0, 569.0, 747.0]], [[1], 0.8837159276008606, [51.657900767122186, 103.97743590987875, 157.41005498988994, 129.50588956440203]], [[0], 0.0, [50.0, 150.0, 569.0, 451.0]], [[1], 0.0, [276.0, 58.0, 347.0, 76.0]]]", "boxes_num": "8"}'}
# #print(ocr_result['data'][0])
# filtered_layout_boxes,filtered_layout_ocrs = filter_boxes(ocr_result, layout_result)
# print(filtered_layout_ocrs)
# print(filtered_layout_boxes)
# # 测试用例
import cv2
import base64
import numpy as np
def get_base64_from_image(image):
"""
将图像转换为Base64编码。
:param image: 输入图像的NumPy数组
:return: Base64编码的字符串
"""
_, buffer = cv2.imencode('.jpg', image)
base64_str = base64.b64encode(buffer).decode('utf-8')
return base64_str
def draw_bounding_boxes(image_path, coordinates_list):
"""
在图像上绘制多个红色边框并返回带框图像的Base64编码。
:param image_path: 图像路径
:param coordinates_list: 每个元素为以8个数字表示的四个角坐标,顺序为左上,右上,右下,左下
:return: 图像Base64编码的字符串
"""
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"Failed to read the image file '{image_path}'.")
for coordinates in coordinates_list:
# 解析输入坐标
x1, y1, x2, y2, x3, y3, x4, y4 = map(int, coordinates)
# 计算矩形边界
x_min = min(x1, x2, x3, x4)
y_min = min(y1, y2, y3, y4)
x_max = max(x1, x2, x3, x4)
y_max = max(y1, y2, y3, y4)
# 确保边界不超出图像范围
x_min = max(0, x_min)
y_min = max(0, y_min)
x_max = min(image.shape[1], x_max)
y_max = min(image.shape[0], y_max)
# 绘制红色矩形框 (B, G, R) -> (0, 0, 255)
cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 0, 255), 2)
# 转换为Base64编码
return get_base64_from_image(image)
def get_formula_boundingbox_base64_list(return_sub_img_path_results):
"""
根据输入路径和坐标信息,返回每张图像包含红色边框的Base64编码列表。
:param return_sub_img_path_results: 字典,键为图像路径,值为包含边界框的坐标列表
:return: Base64编码的列表
"""
base64_list = []
for image_path, coordinates_list in return_sub_img_path_results.items():
# 在图像上绘制红色框并转换为Base64编码
base64_image = draw_bounding_boxes(image_path, coordinates_list)
base64_list.append(base64_image)
return base64_list
# 示例调用
if __name__ == "__main__":
return_sub_img_path_results = {
"example_image.jpg": [
[100, 50, 200, 50, 200, 150, 100, 150],
[300, 200, 400, 200, 400, 300, 300, 300]
]
}
base64_lists = get_formula_boundingbox_base64_list(return_sub_img_path_results)
for base64_str in base64_lists:
print(base64_str)
\ No newline at end of file
import requests
import time
def ocr_service_request(image_url):
"""
向指定的 OCR 服务发送请求,并返回响应信息。
参数:
- image_url: 图像的 URL 路径
返回:
- response_json: 成功时返回的响应内容(JSON 格式)
- elapsed_time: 请求和响应之间的时间(秒)
- status_code: 请求的 HTTP 状态码
- error_message: 请求失败时的错误消息
"""
# 请求数据
data_info = {
"url": image_url,
}
try:
# 记录请求开始的时间
start_time = time.time()
# 发送 POST 请求
response = requests.post('http://localhost:8880/v1/got_ocr_markdown_local', json=data_info)
# 记录请求结束的时间
end_time = time.time()
# 计算请求和响应之间的时间
elapsed_time = end_time - start_time
# 返回请求相关信息
if response.status_code == 200:
return response.json(), elapsed_time, response.status_code, None
else:
return None, elapsed_time, response.status_code, response.text
except requests.exceptions.RequestException as e:
return None, None, None, str(e)
# 调用示例
if __name__ == "__main__":
image_url = '/data/wangtengbo/got_ocr2/infer/demo.png' # 替换为需要处理的图片路径
response_json, elapsed_time, status_code, error_message = ocr_service_request(image_url)
if status_code == 200:
print(f"Request successful! Time taken: {elapsed_time:.2f} seconds")
print("Response content:", response_json)
else:
print(f"Request failed with status code {status_code}. Error: {error_message}")
import os
from datetime import datetime
import langid
import numpy as np
from sklearn.cluster import KMeans
from collections import Counter
from loguru import logger
def mkdir_if_not_exist(path):
if not os.path.exists(path):
os.makedirs(path)
# 获取毫秒级时间
def get_millisecond_time():
current_time = datetime.now()
time_str = current_time.strftime("%Y%m%d%H%M%S%f")[:-3]
return time_str
def get_lang(text):
lang_detect, _ = langid.classify(text.replace('。', ' ').replace(',', ' ')) # 语言检测
return 'en' if lang_detect == 'en' else 'zh'
def get_day_time():
# 获取当前日期和时间
now = datetime.now()
# 格式化日期和时间为字符串,格式为 "YYYYMMDD_HH%M%S"
formatted_time = now.strftime("%Y%m%d_%H%M%S")
return formatted_time
def merge_ver_boxes(formula_positions):
def sort_by_y_min(box):
return box[2][1]
formula_positions.sort(key=sort_by_y_min)
merged_boxes = []
current_box = None
for box in formula_positions:
category, confidence, bbox = box
x_min, y_min, x_max, y_max = bbox
if current_box is None:
current_box = [category, confidence, bbox]
else:
curr_x_min, curr_y_min, curr_x_max, curr_y_max = current_box[2]
if y_min <= curr_y_max:
# Merge the boxes
merged_bbox = [
min(curr_x_min, x_min),
min(curr_y_min, y_min),
max(curr_x_max, x_max),
max(curr_y_max, y_max)
]
current_box[2] = merged_bbox
else:
# Append the current box to merged_boxes list and start a new current_box
merged_boxes.append(current_box)
current_box = [category, confidence, bbox]
if current_box is not None:
merged_boxes.append(current_box)
return merged_boxes
def merge_hor_boxes(formula_positions):
"""
Merges horizontally overlapping or adjacent bounding boxes.
Args:
formula_positions (list): A list of bounding box data where each item is
in the format (category, confidence, [x_min, y_min, x_max, y_max]).
Returns:
list: A list of merged bounding box data in the same format.
"""
def sort_by_x_min(box):
return box[2][0] # Sort by x_min
# Sort the boxes by their x_min coordinate
formula_positions.sort(key=sort_by_x_min)
merged_boxes = []
current_box = None
for box in formula_positions:
category, confidence, bbox = box
x_min, y_min, x_max, y_max = bbox
if current_box is None:
current_box = [category, confidence, bbox]
else:
curr_x_min, curr_y_min, curr_x_max, curr_y_max = current_box[2]
if x_min <= curr_x_max: # Check if boxes overlap or are adjacent horizontally
# Merge the boxes
merged_bbox = [
min(curr_x_min, x_min),
min(curr_y_min, y_min),
max(curr_x_max, x_max),
max(curr_y_max, y_max)
]
current_box[2] = merged_bbox
else:
# Append the current box to merged_boxes list and start a new current_box
merged_boxes.append(current_box)
current_box = [category, confidence, bbox]
if current_box is not None:
merged_boxes.append(current_box)
return merged_boxes
def merge_boxes_by_clustering(formula_positions, target_num_boxes=4):
# Extract bounding box centers as features for clustering
box_centers = np.array([((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2) for _, _, bbox in formula_positions])
# Initialize KMeans with the target number of clusters (boxes)
kmeans = KMeans(n_clusters=target_num_boxes, random_state=0).fit(box_centers)
# Assign cluster labels to each bounding box
labels = kmeans.labels_
# Merge bounding boxes based on cluster labels
merged_boxes = []
for i in range(target_num_boxes):
boxes_in_cluster = [formula_positions[j] for j in range(len(formula_positions)) if labels[j] == i]
# Calculate merged bounding box for each cluster
min_x = min([bbox[0] for _, _, bbox in boxes_in_cluster])
min_y = min([bbox[1] for _, _, bbox in boxes_in_cluster])
max_x = max([bbox[2] for _, _, bbox in boxes_in_cluster])
max_y = max([bbox[3] for _, _, bbox in boxes_in_cluster])
merged_bbox = [min_x, min_y, max_x, max_y]
merged_boxes.append((None, None, merged_bbox)) # Replace None with category and confidence if needed
return merged_boxes
def merge_horizontal_boxes(formula_positions, num_big_boxes):
"""
将给定的公式框从上到下合并成指定数量的大框,并保证这些大框之间没有垂直重叠。
参数:
formula_positions: List[Tuple[List[int], float, List[float]]]
每个公式的位置数据结构为 (category_list, confidence, [x_min, y_min, x_max, y_max])
num_big_boxes: int
期望最终合并得到的大框数量
返回:
List[Tuple[List[int], float, List[float]]]
返回合并后的大框列表,结构与输入相同。
"""
if not formula_positions:
return []
# 按 y_min 排序
formula_positions.sort(key=lambda box: box[2][1])
total = len(formula_positions)
group_size = (total + num_big_boxes - 1) // num_big_boxes # 向上取整分组
merged_boxes = []
for i in range(num_big_boxes):
start_idx = i * group_size
end_idx = (i + 1) * group_size
group = formula_positions[start_idx:end_idx]
if not group:
continue
# 合并本组框
category = group[0][0]
confidence = group[0][1]
x_min, y_min, x_max, y_max = group[0][2]
for j in range(1, len(group)):
c, conf, bbox = group[j]
# 更新置信度,如取最大值
if conf > confidence:
confidence = conf
gx_min, gy_min, gx_max, gy_max = bbox
x_min = min(x_min, gx_min)
y_min = min(y_min, gy_min)
x_max = max(x_max, gx_max)
y_max = max(y_max, gy_max)
merged_boxes.append([category, confidence, [x_min, y_min, x_max, y_max]])
# 确保大框之间无重叠
# 假设merged_boxes已按y_min排序(因为我们在分组时就是按排序后的顺序合并的)
for i in range(1, len(merged_boxes)):
prev_box = merged_boxes[i - 1]
curr_box = merged_boxes[i]
_, _, [prev_x_min, prev_y_min, prev_x_max, prev_y_max] = prev_box
_, _, [curr_x_min, curr_y_min, curr_x_max, curr_y_max] = curr_box
# 如果当前框的y_min <= 上一个框的y_max,说明有重叠,需要调整
if curr_y_min <= prev_y_max:
# 将当前框向下平移,使得curr_y_min = prev_y_max + 1
shift = (prev_y_max + 1) - curr_y_min
curr_y_min += shift
curr_y_max += shift
# 更新当前框的坐标
curr_box[2] = [curr_x_min, curr_y_min, curr_x_max, curr_y_max]
return merged_boxes
def process_formula_positions(formula_positions, target_num_boxes=2):
"""
处理公式位置,合并垂直方向上的框和聚类框,并返回最终的边界框。
参数:
formula_positions: List[Tuple[List[int], float, List[float]]]
每个公式的位置,由类别、置信度和边界框组成。
target_num_boxes: int
目标聚类框的数量。
返回:
List[List[float]]: 合并后的公式边界框列表。
"""
# print(formula_positions)
# hor_merges=merge_horizontal_boxes(formula_positions)
# logger.info(f'hor_merges={hor_merges}')
#合并垂直方向上的框
merged_boxes_ver = merge_ver_boxes(formula_positions)
#print(len(merged_boxes_ver))
merged_boxes_hor=merge_horizontal_boxes(merged_boxes_ver,target_num_boxes)
#print(merged_boxes_hor)
#print(len(merged_boxes_hor))
#merged_boxes_hor = merge_hor_boxes(formula_positions)
# if len(formula_positions) < target_num_boxes:
# target_num_boxes = len(formula_positions)
# # 使用聚类算法合并框,仅保留 target_num_boxes 个框
# merged_boxes = merge_boxes_by_clustering(merged_boxes_hor, target_num_boxes=target_num_boxes)
# 提取合并后的边界框信息
formula_boxes = [data[2] for data in merged_boxes_hor]
return formula_boxes
import numpy as np
def merge_bounding_boxes(formula_positions):
# 转换为NumPy数组
positions_array = np.array(formula_positions)
# 找到最左、最上、最右、最下的点
min_x1 = np.min(positions_array[:, 0])
min_y1 = np.min(positions_array[:, 1])
max_x2 = np.max(positions_array[:, 2])
max_y2 = np.max(positions_array[:, 3])
# 返回新的大的边界框
return [min_x1, min_y1, max_x2, max_y2]
if __name__ == "__main__":
# 示例
formula_positions = [
[1, 2, 3, 4],
[2, 3, 5, 6],
[0, 1, 4, 5]
]
merged_box = merge_bounding_boxes(formula_positions)
print(merged_box) # 输出:[0, 1, 5, 6]
import os
from datetime import datetime
import langid
import numpy as np
from sklearn.cluster import KMeans
def mkdir_if_not_exist(path):
"""
如果目录不存在,则创建目录。
参数:
path (str): 要创建的目录路径。
"""
if not os.path.exists(path):
os.makedirs(path)
def get_millisecond_time():
"""
获取当前时间(精确到毫秒)。
返回:
str: 格式化的当前时间字符串(精确到毫秒)。
"""
current_time = datetime.now()
time_str = current_time.strftime("%Y%m%d%H%M%S%f")[:-3]
return time_str
def get_lang(text):
"""
检测给定文本的语言。
参数:
text (str): 要检测语言的文本。
返回:
str: 'en' 表示英文,'zh' 表示中文。
"""
lang_detect, _ = langid.classify(text.replace('。', ' ').replace(',', ' '))
return 'en' if lang_detect == 'en' else 'zh'
def get_day_time():
"""
获取当前日期和时间的格式化字符串。
返回:
str: 格式化的当前日期和时间字符串(格式为 "YYYYMMDD_HH%M%S")。
"""
now = datetime.now()
formatted_time = now.strftime("%Y%m%d_%H%M%S")
return formatted_time
def merge_ver_boxes(formula_positions):
"""
合并垂直方向上重叠的边界框。
参数:
formula_positions (list): 包含类别、置信度和边界框的元组列表。
返回:
list: 合并后的边界框。
"""
def sort_by_y_min(box):
return box[2][1]
formula_positions.sort(key=sort_by_y_min)
merged_boxes = []
current_box = None
for box in formula_positions:
category, confidence, bbox = box
x_min, y_min, x_max, y_max = bbox
if current_box is None:
current_box = [category, confidence, bbox]
else:
curr_x_min, curr_y_min, curr_x_max, curr_y_max = current_box[2]
if y_min <= curr_y_max:
merged_bbox = [
min(curr_x_min, x_min),
min(curr_y_min, y_min),
max(curr_x_max, x_max),
max(curr_y_max, y_max)
]
current_box[2] = merged_bbox
else:
merged_boxes.append(current_box)
current_box = [category, confidence, bbox]
if current_box is not None:
merged_boxes.append(current_box)
return merged_boxes
def merge_hor_boxes(formula_positions):
"""
合并水平方向上重叠的边界框。
参数:
formula_positions (list): 包含类别、置信度和边界框的元组列表。
返回:
list: 合并后的边界框。
"""
def sort_by_x_min(box):
return box[2][0]
formula_positions.sort(key=sort_by_x_min)
merged_boxes = []
current_box = None
for box in formula_positions:
category, confidence, bbox = box
x_min, y_min, x_max, y_max = bbox
if current_box is None:
current_box = [category, confidence, bbox]
else:
curr_x_min, curr_y_min, curr_x_max, curr_y_max = current_box[2]
if x_min <= curr_x_max:
merged_bbox = [
min(curr_x_min, x_min),
min(curr_y_min, y_min),
max(curr_x_max, x_max),
max(curr_y_max, y_max)
]
current_box[2] = merged_bbox
else:
merged_boxes.append(current_box)
current_box = [category, confidence, bbox]
if current_box is not None:
merged_boxes.append(current_box)
return merged_boxes
def merge_boxes_by_clustering(formula_positions, target_num_boxes=4):
"""
使用聚类算法合并边界框。
参数:
formula_positions (list): 包含类别、置信度和边界框的元组列表。
target_num_boxes (int): 目标合并成的边界框数量。
返回:
list: 基于聚类的合并边界框。
"""
box_centers = np.array([((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2) for _, _, bbox in formula_positions])
kmeans = KMeans(n_clusters=target_num_boxes, random_state=0).fit(box_centers)
labels = kmeans.labels_
merged_boxes = []
for i in range(target_num_boxes):
boxes_in_cluster = [formula_positions[j] for j in range(len(formula_positions)) if labels[j] == i]
if not boxes_in_cluster:
continue
min_x = min([bbox[0] for _, _, bbox in boxes_in_cluster])
min_y = min([bbox[1] for _, _, bbox in boxes_in_cluster])
max_x = max([bbox[2] for _, _, bbox in boxes_in_cluster])
max_y = max([bbox[3] for _, _, bbox in boxes_in_cluster])
merged_bbox = [min_x, min_y, max_x, max_y]
merged_boxes.append((None, None, merged_bbox))
return merged_boxes
def process_formula_positionsv2(formula_positions, target_num_boxes=4):
"""
通过合并垂直、水平和聚类框来处理公式位置。
参数:
formula_positions (list): 包含类别、置信度和边界框的元组列表。
target_num_boxes (int): 目标合并成的边界框数量。
返回:
list: 最终合并的边界框列表。
"""
merged_boxes_ver = merge_ver_boxes(formula_positions)
#merged_boxes_hor = merge_hor_boxes(merged_boxes_ver)
# print(merged_boxes_ver)
# print(merged_boxes_hor)
if len(merged_boxes_ver) < target_num_boxes:
target_num_boxes = len(merged_boxes_ver)
merged_boxes = merge_boxes_by_clustering(merged_boxes_ver, target_num_boxes=target_num_boxes)
formula_boxes = [data[2] for data in merged_boxes]
return formula_boxes
def merge_bounding_boxes(formula_positions):
"""
将多个边界框合并为一个大的边界框。
参数:
formula_positions (list): 边界框列表。
返回:
list: 合并后的大边界框。
"""
positions_array = np.array(formula_positions)
min_x1 = np.min(positions_array[:, 0])
min_y1 = np.min(positions_array[:, 1])
max_x2 = np.max(positions_array[:, 2])
max_y2 = np.max(positions_array[:, 3])
return [min_x1, min_y1, max_x2, max_y2]
if __name__ == "__main__":
formula_positions = [[[6], 0.0, [1609.0, 199.0, 2127.0, 329.0]], [[0], 0.9736799597740173, [225.93232349219818, 1611.367297878681, 2017.0720675821328, 2075.073957172845]], [[0], 0.9499262571334839, [1579.0, 2211.0, 1992.0, 2841.0]], [[1], 0.4914904832839966, [320.0, 2088.0, 739.0, 2161.0]], [[1], 0.4311352074146271, [228.9372284589486, 1553.0413806681088, 726.2493805957434, 1587.8428093167895]], [[1], 0.0, [765.0, 972.0, 842.0, 1011.0]], [[1], 0.0, [778.0, 1329.0, 848.0, 1368.0]], [[1], 0.0, [431.0, 717.0, 634.0, 766.0]], [[1], 0.0, [327.0, 1152.0, 554.0, 1195.0]], [[1], 0.0, [351.0, 1460.0, 735.0, 1506.0]], [[1], 0.0, [1012.0, 517.0, 1352.0, 566.0]], [[0], 0.0, [1059.0, 713.0, 1867.0, 874.0]], [[1], 0.0, [1062.0, 930.0, 1696.0, 962.0]], [[1], 0.0, [1055.0, 1024.0, 1249.0, 1093.0]], [[1], 0.0, [1059.0, 1149.0, 1840.0, 1182.0]], [[1], 0.0, [1095.0, 1247.0, 1980.0, 1345.0]], [[1], 0.0, [1092.0, 1413.0, 1346.0, 1486.0]], [[7], 0.4297367036342621, [1976.0, 2958.0, 2067.0, 2994.0]], [[0], 0.0, [227.0, 2235.0, 1453.0, 2854.0]]]
#formula_positions=[data[2] for data in formula_positions]
merged_box = process_formula_positions(formula_positions)
print(merged_box) # 输出:[0, 1, 5, 6]
import os
from datetime import datetime
import langid
import numpy as np
from sklearn.cluster import KMeans
def mkdir_if_not_exist(path):
"""
如果目录不存在,则创建目录。
参数:
path (str): 要创建的目录路径。
"""
if not os.path.exists(path):
os.makedirs(path)
def get_millisecond_time():
"""
获取当前时间(精确到毫秒)。
返回:
str: 格式化的当前时间字符串(精确到毫秒)。
"""
current_time = datetime.now()
time_str = current_time.strftime("%Y%m%d%H%M%S%f")[:-3]
return time_str
def get_lang(text):
"""
检测给定文本的语言。
参数:
text (str): 要检测语言的文本。
返回:
str: 'en' 表示英文,'zh' 表示中文。
"""
lang_detect, _ = langid.classify(text.replace('。', ' ').replace(',', ' '))
return 'en' if lang_detect == 'en' else 'zh'
def get_day_time():
"""
获取当前日期和时间的格式化字符串。
返回:
str: 格式化的当前日期和时间字符串(格式为 "YYYYMMDD_HH%M%S")。
"""
now = datetime.now()
formatted_time = now.strftime("%Y%m%d_%H%M%S")
return formatted_time
def merge_ver_boxes(formula_positions):
"""
合并垂直方向上重叠的边界框。
参数:
formula_positions (list): 包含类别、置信度和边界框的元组列表。
返回:
list: 合并后的边界框。
"""
def sort_by_y_min(box):
return box[2][1]
formula_positions.sort(key=sort_by_y_min)
merged_boxes = []
current_box = None
for box in formula_positions:
category, confidence, bbox = box
x_min, y_min, x_max, y_max = bbox
if current_box is None:
current_box = [category, confidence, bbox]
else:
curr_x_min, curr_y_min, curr_x_max, curr_y_max = current_box[2]
if y_min <= curr_y_max:
merged_bbox = [
min(curr_x_min, x_min),
min(curr_y_min, y_min),
max(curr_x_max, x_max),
max(curr_y_max, y_max)
]
current_box[2] = merged_bbox
else:
merged_boxes.append(current_box)
current_box = [category, confidence, bbox]
if current_box is not None:
merged_boxes.append(current_box)
return merged_boxes
def merge_hor_boxes(formula_positions):
"""
合并水平方向上重叠的边界框。
参数:
formula_positions (list): 包含类别、置信度和边界框的元组列表。
返回:
list: 合并后的边界框。
"""
def sort_by_x_min(box):
return box[2][0]
formula_positions.sort(key=sort_by_x_min)
merged_boxes = []
current_box = None
for box in formula_positions:
category, confidence, bbox = box
x_min, y_min, x_max, y_max = bbox
if current_box is None:
current_box = [category, confidence, bbox]
else:
curr_x_min, curr_y_min, curr_x_max, curr_y_max = current_box[2]
if x_min <= curr_x_max:
merged_bbox = [
min(curr_x_min, x_min),
min(curr_y_min, y_min),
max(curr_x_max, x_max),
max(curr_y_max, y_max)
]
current_box[2] = merged_bbox
else:
merged_boxes.append(current_box)
current_box = [category, confidence, bbox]
if current_box is not None:
merged_boxes.append(current_box)
return merged_boxes
def merge_boxes_by_clustering(formula_positions, target_num_boxes=4):
"""
使用聚类算法合并边界框。
参数:
formula_positions (list): 包含类别、置信度和边界框的元组列表。
target_num_boxes (int): 目标合并成的边界框数量。
返回:
list: 基于聚类的合并边界框。
"""
box_centers = np.array([((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2) for _, _, bbox in formula_positions])
kmeans = KMeans(n_clusters=target_num_boxes, random_state=0).fit(box_centers)
labels = kmeans.labels_
merged_boxes = []
for i in range(target_num_boxes):
boxes_in_cluster = [formula_positions[j] for j in range(len(formula_positions)) if labels[j] == i]
if not boxes_in_cluster:
continue
min_x = min([bbox[0] for _, _, bbox in boxes_in_cluster])
min_y = min([bbox[1] for _, _, bbox in boxes_in_cluster])
max_x = max([bbox[2] for _, _, bbox in boxes_in_cluster])
max_y = max([bbox[3] for _, _, bbox in boxes_in_cluster])
merged_bbox = [min_x, min_y, max_x, max_y]
merged_boxes.append((None, None, merged_bbox))
return merged_boxes
def process_formula_positions_v2(formula_positions, target_num_boxes=5):
"""
通过合并垂直、水平和聚类框来处理公式位置。
参数:
formula_positions (list): 包含类别、置信度和边界框的元组列表。
target_num_boxes (int): 目标合并成的边界框数量。
返回:
list: 最终合并的边界框列表。
"""
#merged_boxes_ver = merge_ver_boxes(formula_positions)
merged_boxes_hor = merge_hor_boxes(formula_positions)
# print(merged_boxes_ver)
# print(merged_boxes_hor)
if len(merged_boxes_hor) < target_num_boxes:
target_num_boxes = len(merged_boxes_hor)
merged_boxes = merge_boxes_by_clustering(merged_boxes_hor, target_num_boxes=target_num_boxes)
formula_boxes = [data[2] for data in merged_boxes]
return formula_boxes
def merge_bounding_boxes(formula_positions):
"""
将多个边界框合并为一个大的边界框。
参数:
formula_positions (list): 边界框列表。
返回:
list: 合并后的大边界框。
"""
positions_array = np.array(formula_positions)
min_x1 = np.min(positions_array[:, 0])
min_y1 = np.min(positions_array[:, 1])
max_x2 = np.max(positions_array[:, 2])
max_y2 = np.max(positions_array[:, 3])
return [min_x1, min_y1, max_x2, max_y2]
if __name__ == "__main__":
formula_positions = [[[6], 0.0, [1609.0, 199.0, 2127.0, 329.0]], [[0], 0.9736799597740173, [225.93232349219818, 1611.367297878681, 2017.0720675821328, 2075.073957172845]], [[0], 0.9499262571334839, [1579.0, 2211.0, 1992.0, 2841.0]], [[1], 0.4914904832839966, [320.0, 2088.0, 739.0, 2161.0]], [[1], 0.4311352074146271, [228.9372284589486, 1553.0413806681088, 726.2493805957434, 1587.8428093167895]], [[1], 0.0, [765.0, 972.0, 842.0, 1011.0]], [[1], 0.0, [778.0, 1329.0, 848.0, 1368.0]], [[1], 0.0, [431.0, 717.0, 634.0, 766.0]], [[1], 0.0, [327.0, 1152.0, 554.0, 1195.0]], [[1], 0.0, [351.0, 1460.0, 735.0, 1506.0]], [[1], 0.0, [1012.0, 517.0, 1352.0, 566.0]], [[0], 0.0, [1059.0, 713.0, 1867.0, 874.0]], [[1], 0.0, [1062.0, 930.0, 1696.0, 962.0]], [[1], 0.0, [1055.0, 1024.0, 1249.0, 1093.0]], [[1], 0.0, [1059.0, 1149.0, 1840.0, 1182.0]], [[1], 0.0, [1095.0, 1247.0, 1980.0, 1345.0]], [[1], 0.0, [1092.0, 1413.0, 1346.0, 1486.0]], [[7], 0.4297367036342621, [1976.0, 2958.0, 2067.0, 2994.0]], [[0], 0.0, [227.0, 2235.0, 1453.0, 2854.0]]]
#formula_positions=[data[2] for data in formula_positions]
merged_box = process_formula_positions_v2(formula_positions)
print(merged_box) # 输出:[0, 1, 5, 6]
import re
def find_and_highlight_substring(A, B):
# 正则表达式匹配B
pattern = re.escape(B)
# 存储匹配到的内容的起始和结束下标
matches = []
# 处理单个$符号内的内容
def single_dollar_replacement(match):
start = match.start(1)
end = match.end(1)
inner_text = match.group(1)
highlighted_text = re.sub(pattern, r'<span style="color:red">\g<0></span>', inner_text)
matches.append((start, end))
return f"${highlighted_text}$"
# 处理双个$$符号内的内容
def double_dollar_replacement(match):
start = match.start(1)
end = match.end(1)
inner_text = match.group(1)
highlighted_text = re.sub(pattern, r'<span style="color:red">\g<0></span>', inner_text)
matches.append((start, end))
return f"$${highlighted_text}$$"
# 使用正则表达式替换并标记匹配位置
highlighted_A = re.sub(r'\$(.*?)\$', single_dollar_replacement, A)
highlighted_A = re.sub(r'\$\$(.*?)\$\$', double_dollar_replacement, highlighted_A)
return highlighted_A, matches
# # 示例
# A = '# 1.4.1 用空间向量研究直线、平面的位置关系(第3课时)\n\n## 空间中直线、平面的垂直\n\n### 知识清单\n\n1. 直线和直线垂直\n设直线 $l_1, l_2$ 的方向向量分别为 $$u_1, u_2$$,则 $$l_1 \perp l_2 \Leftrightarrow u_1 \cdot u_2 = 0$$。'
# B = 'l_1 \perp l_2'
# highlighted_A, matches = find_and_highlight_substring(A, B)
# print("Highlighted A:\n", highlighted_A)
# print("Matches:", matches)
import xml.etree.ElementTree as ET
import cv2
import random
def parse_xml(file_path):
"""
解析XML文件,提取公式框和文本框的坐标。
参数:
file_path (str): XML文件路径。
返回:
tuple: 公式框列表和文本框列表,格式为[(xmin, ymin, xmax, ymax), ...]。
"""
tree = ET.parse(file_path)
root = tree.getroot()
formula_boxes = []
text_boxes = []
for obj in root.findall('object'):
name = obj.find('name').text
xmin = int(obj.find('bndbox/xmin').text)
ymin = int(obj.find('bndbox/ymin').text)
xmax = int(obj.find('bndbox/xmax').text)
ymax = int(obj.find('bndbox/ymax').text)
box = (xmin, ymin, xmax, ymax)
if name == 'formula':
formula_boxes.append(box)
elif name == 'text':
text_boxes.append(box)
return formula_boxes, text_boxes
def find_closest_text_boxes(formula_box, text_boxes):
"""
找到给定公式框最近的左、右、上、下文本框。
参数:
formula_box (tuple): 公式框的坐标(xmin, ymin, xmax, ymax)。
text_boxes (list): 文本框列表,格式为[(xmin, ymin, xmax, ymax), ...]。
返回:
tuple: 最近的左、右、上、下文本框的坐标,如果没有则为None。
"""
left_text = right_text = top_text = bottom_text = None
left_dist = right_dist = top_dist = bottom_dist = float('inf')
xmin_f, ymin_f, xmax_f, ymax_f = formula_box
for text_box in text_boxes:
xmin_t, ymin_t, xmax_t, ymax_t = text_box
if check_overlap(formula_box, text_box):
continue
if xmax_t < xmin_f and xmin_f - xmax_t < left_dist:
left_dist = xmin_f - xmax_t
left_text = text_box
if xmin_t > xmax_f and xmin_t - xmax_f < right_dist:
right_dist = xmin_t - xmax_f
right_text = text_box
if ymax_t < ymin_f and ymin_f - ymax_t < top_dist:
top_dist = ymin_f - ymax_t
top_text = text_box
if ymin_t > ymax_f and ymin_t - ymax_f < bottom_dist:
bottom_dist = ymin_t - ymax_f
bottom_text = text_box
return left_text, right_text, top_text, bottom_text
def calculate_iou(box1, box2):
"""
计算两个框的交并比(IoU)。
参数:
box1 (tuple): 第一个框的坐标(xmin, ymin, xmax, ymax)。
box2 (tuple): 第二个框的坐标(xmin, ymin, xmax, ymax)。
返回:
float: 两个框的交并比(IoU)。
"""
xmin1, ymin1, xmax1, ymax1 = box1
xmin2, ymin2, xmax2, ymax2 = box2
# 计算交集坐标
inter_xmin = max(xmin1, xmin2)
inter_ymin = max(ymin1, ymin2)
inter_xmax = min(xmax1, xmax2)
inter_ymax = min(ymax1, ymax2)
# 计算交集面积
inter_area = max(0, inter_xmax - inter_xmin + 1) * max(0, inter_ymax - inter_ymin + 1)
# 计算每个框的面积
box1_area = (xmax1 - xmin1 + 1) * (ymax1 - ymin1 + 1)
box2_area = (xmax2 - xmin2 + 1) * (ymax2 - ymax2 + 1)
# 计算并集面积
union_area = box1_area + box2_area - inter_area
# 计算IoU
iou = inter_area / union_area
return iou
def check_overlap(box1, box2, iou_threshold=0.1):
"""
检查两个框是否有重叠,并且重叠的IoU在一定范围内是可以接受的。
参数:
box1 (tuple): 第一个框的坐标(xmin, ymin, xmax, ymax)。
box2 (tuple): 第二个框的坐标(xmin, ymin, xmax, ymax)。
iou_threshold (float): 可接受的IoU阈值。
返回:
bool: 如果两个框的IoU在可接受范围内则返回True,否则返回False。
"""
iou = calculate_iou(box1, box2)
return iou >= iou_threshold
def check_no_overlap_with_others(expanded_box, formula_boxes, current_box):
"""
检查扩展后的框是否与其他公式框重叠。
参数:
expanded_box (tuple): 扩展后的框的坐标(xmin, ymin, xmax, ymax)。
formula_boxes (list): 所有公式框的列表。
current_box (tuple): 当前正在处理的公式框。
返回:
bool: 如果扩展后的框与其他公式框没有重叠则返回True,否则返回False。
"""
for other_box in formula_boxes:
if other_box != current_box and check_overlap(expanded_box, other_box):
return False
return True
def expand_formula_boxes(formula_boxes, text_boxes):
"""
扩展公式框,使其与最近的文本框融合。
参数:
formula_boxes (list): 公式框列表,格式为[(xmin, ymin, xmax, ymax), ...]。
text_boxes (list): 文本框列表,格式为[(xmin, ymin, xmax, ymax), ...]。
返回:
list: 扩展后的公式框列表。
"""
expanded_formula_boxes = []
for formula_box in formula_boxes:
xmin_f, ymin_f, xmax_f, ymax_f = formula_box
left_text, right_text, top_text, bottom_text = find_closest_text_boxes(formula_box, text_boxes)
# 扩展左边
if left_text:
xmin_t, ymin_t, xmax_t, ymax_t = left_text
new_xmin_f = xmin_t
expanded_box = (new_xmin_f, min(ymin_f, ymin_t), xmax_f, max(ymax_f, ymax_t))
if check_no_overlap_with_others(expanded_box, formula_boxes, formula_box):
xmin_f = new_xmin_f
ymin_f = min(ymin_f, ymin_t)
ymax_f = max(ymax_f, ymax_t)
# 扩展右边
if right_text:
xmin_t, ymin_t, xmax_t, ymax_t = right_text
new_xmax_f = xmax_t
expanded_box = (xmin_f, min(ymin_f, ymin_t), new_xmax_f, max(ymax_f, ymax_t))
if check_no_overlap_with_others(expanded_box, formula_boxes, formula_box):
xmax_f = new_xmax_f
ymin_f = min(ymin_f, ymin_t)
ymax_f = max(ymax_f, ymax_t)
# 扩展上边
if top_text:
xmin_t, ymin_t, xmax_t, ymax_t = top_text
new_ymin_f = ymin_t
expanded_box = (min(xmin_f, xmin_t), new_ymin_f, max(xmax_f, xmax_t), ymax_f)
if check_no_overlap_with_others(expanded_box, formula_boxes, formula_box):
ymin_f = new_ymin_f
xmin_f = min(xmin_f, xmin_t)
xmax_f = max(xmax_f, xmax_t)
# 扩展下边
if bottom_text:
xmin_t, ymin_t, xmax_t, ymax_t = bottom_text
new_ymax_f = ymax_t
expanded_box = (min(xmin_f, xmin_t), ymin_f, max(xmax_f, xmax_t), new_ymax_f)
if check_no_overlap_with_others(expanded_box, formula_boxes, formula_box):
ymax_f = new_ymax_f
xmin_f = min(xmin_f, xmin_t)
xmax_f = max(xmax_f, xmax_t)
expanded_formula_boxes.append((xmin_f, ymin_f, xmax_f, ymax_f))
return expanded_formula_boxes
def draw_boxes_on_image(image_path, boxes, output_path):
"""
在图像上绘制框并保存结果图像。
参数:
image_path (str): 输入图像的路径。
boxes (list): 需要绘制的框的列表,格式为[(xmin, ymin, xmax, ymax), ...]。
output_path (str): 保存结果图像的路径。
"""
image = cv2.imread(image_path)
def random_color():
return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
for box in boxes:
xmin, ymin, xmax, ymax = box
color = random_color()
cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
cv2.imwrite(output_path, image)
print(f"扩展后的图像已保存到 {output_path}")
def main(xml_file_path, image_file_path, output_image_path):
"""
主函数,执行XML解析、公式框扩展和绘制框的操作。
参数:
xml_file_path (str): XML文件的路径。
image_file_path (str): 输入图像的路径。
output_image_path (str): 保存结果图像的路径。
"""
formula_boxes, text_boxes = parse_xml(xml_file_path)
expanded_formula_boxes = expand_formula_boxes(formula_boxes, text_boxes)
for box in expanded_formula_boxes:
print(box)
draw_boxes_on_image(image_file_path, expanded_formula_boxes, output_image_path)
if __name__ == "__main__":
main('1.xml', '1.png', './expanded_image.png')
import requests
import os
import mimetypes
from typing import Dict, Optional, Union, Tuple
from urllib.parse import quote
class OBSUploader:
def __init__(self, base_url: str = "https://open.5rs.me", auth_token: Optional[str] = None):
"""
Initialize the OBS uploader.
Args:
base_url: The base URL for the API
auth_token: The authorization token for API access
"""
self.base_url = base_url.rstrip('/')
self.auth_token = auth_token
self.headers = {
'Authorization': f'Bearer {auth_token}' if auth_token else None
}
# Initialize mimetypes
mimetypes.init()
def _get_content_type(self, file_path: Union[str, bytes]) -> Tuple[str, bytes]:
"""
Get content type and file content from file path or bytes.
Args:
file_path: Path to the file or file content as bytes
Returns:
Tuple of (content_type, file_content)
"""
if isinstance(file_path, str):
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
content_type, _ = mimetypes.guess_type(file_path)
with open(file_path, 'rb') as f:
file_content = f.read()
else:
file_content = file_path
# For bytes input, try to detect type from first few bytes
content_type = 'application/octet-stream' # Default content type
return content_type or 'application/octet-stream', file_content
def get_upload_url(self, biz_code: str, object_name: str, content_type: str) -> Dict:
"""
Get a temporary upload URL for the specified object.
Args:
biz_code: Business code for the upload
object_name: Name/path of the object to upload
content_type: MIME type of the file
Returns:
Dict containing the upload URL and related information
"""
endpoint = f"{self.base_url}/aimodel/v1.0/obs/getCreatePostSignature"
params = {
'bizCode': biz_code,
'objectName': object_name,
'mimeType': content_type
}
response = requests.get(endpoint, params=params, headers=self.headers)
response.raise_for_status()
return response.json()
def upload_file(self, file_path: Union[str, bytes], biz_code: str, object_name: str) -> Dict:
"""
Upload a file using temporary credentials.
Args:
file_path: Path to the file to upload or file content as bytes
biz_code: Business code for the upload
object_name: Name/path of the object to upload
Returns:
Dict containing the upload result and file URL
"""
# Get content type and file content
content_type, file_content = self._get_content_type(file_path)
# Get temporary upload URL with content type
upload_info = self.get_upload_url(biz_code, object_name, content_type)
if upload_info['errCode'] != 0:
raise Exception(f"Failed to get upload URL: {upload_info['message']}")
upload_url = upload_info['data']['temporarySignatureUrl']
# Upload the file with the correct content type
headers = {
'Content-Type': content_type,
'Content-Length': str(len(file_content))
}
response = requests.put(upload_url, data=file_content, headers=headers)
response.raise_for_status()
return {
'success': True,
'file_url': upload_info['data']['domain'] + '/' + object_name,
'object_url_map': upload_info['data']['objectUrlMap']
}
# Example usage:
if __name__ == "__main__":
# Initialize uploader
uploader = OBSUploader(auth_token="dcg-4c1e3a7f4fcd415e8c93151ff539d20a")
# Upload a file
try:
result = uploader.upload_file(
file_path="/data/wangtengbo/formula_node4_生产/logs/2025-06-06/draw_box_sub_images/0a9fb8f899c74d979c7dce58f61ff00e/formula_1.png",
biz_code="formula",
object_name="image/test.jpg"
)
print(result)
print(f"File uploaded successfully! URL: {result['file_url']}")
except Exception as e:
print(f"Upload failed: {str(e)}")
\ No newline at end of file
import matplotlib.pyplot as plt
import matplotlib as mpl
# 设置 LaTeX 渲染
mpl.rcParams['text.usetex'] = True
mpl.rcParams['font.size'] = 12
# LaTeX 公式
latex_formula = r'$a \perp b \Leftrightarrow a \cdot b = 0$'
# 创建一个图形和轴
fig, ax = plt.subplots()
# 隐藏轴
ax.axis('off')
# 显示 LaTeX 公式
ax.text(0.5, 0.5, latex_formula, fontsize=20, ha='center', va='center')
# 保存图形为图像
output_image_path = "latex_formula.png"
plt.savefig(output_image_path, bbox_inches='tight')
# 显示图像
plt.show()
import re
from sympy.parsing.latex import parse_latex
def parse_and_normalize_latex(latex_str):
"""
使用 sympy 解析 LaTeX 并标准化表达式。
:param latex_str: LaTeX 表达式字符串
:return: 标准化后的字符串表示
"""
try:
expr = parse_latex(latex_str)
return str(expr)
except Exception as e:
print(f"Error parsing LaTeX: {e}")
return None
def extract_and_normalize_formulas(text):
"""
从文本中提取公式并标准化。
:param text: 包含公式的文本
:return: 标准化后的公式列表
"""
formulas = []
# 提取 LaTeX 公式
latex_matches = re.findall(r'\\[a-zA-Z]+|\\frac|\\sqrt|[A-Za-z]+=|\\[A-Za-z0-9]+', text)
for match in latex_matches:
normalized = parse_and_normalize_latex(match)
if normalized:
formulas.append(normalized)
return formulas
def check_formula_in_text(target_formula, text):
"""
判断目标公式是否存在于文本中。
:param target_formula: LaTeX 表达式的目标公式
:param text: 包含公式的目标文本
:return: True 如果目标公式在文本中,否则 False
"""
normalized_target_formula = parse_and_normalize_latex(target_formula)
if not normalized_target_formula:
return False
normalized_formulas_in_text = extract_and_normalize_formulas(text)
return normalized_target_formula in normalized_formulas_in_text
def process_text(text):
"""
去除文本中的所有空行,并按照换行符进行分割。
:param text: 输入的多行字符串
:return: 去除空行后按 \n 分割的列表
"""
# 去除空行
non_empty_lines = [r'{}'.format(line) for line in text.splitlines() if line.strip()]
return non_empty_lines
def Scan_Content_Aggregation(all_page_details):
all_lines_info=process_text(all_page_details)
print(all_lines_info)
#all_results_markdown=[line for line in all_lines_info if is_break(line) and check_numbers_in_string(line)]
#all_page_details=[[{},{}],[{},{}]]
#过滤内容
# print(all_lines_info)
# for line in all_lines_info:
# if not is_break(line):
# continue
# if not check_numbers_in_string(line):
# continue
# all_results_markdown.append(line)
return ''
# 测试案例
if __name__ == "__main__":
original_text="""
**(多媒体展示)填空:**
(1)$\sqrt { 4 } \times \sqrt { 9 } =$ ,$\sqrt { 4 \times 9 } =$ ;
(2)$\sqrt { 2 5 } \times \sqrt { 1 6 } =$ ,$\sqrt { 2 5 \times 1 6 } =$ ;
(3)$\sqrt { \frac { 1 } { 9 } } \times \sqrt { 3 6 } =$ ,
$\sqrt { \frac { 1 } { 9 } \times 3 6 } =$ ;
(4)$\sqrt { 1 0 0 } \times \sqrt { 0 } =$ ,$\sqrt { 1 0 0 \times 0 } =$
生:(1)$\sqrt { 4 } \times \sqrt { 9 } = 6$,$\sqrt { 4 \times 9 } = 6$;(2)$\sqrt { 2 5 } \times \sqrt { 1 6 } = 2 0$,$\sqrt { 2 5 \times 1 6 } = 2 0 ; ( 3 ) \sqrt { \frac { 1 } { 9 } }$
$\times \sqrt { 3 6 } = 2$ $\sqrt { \frac { 1 } { 9 } } \times 3 6 = 2 ;$ $; ( 4 ) \sqrt { 1 0 0 } \times \sqrt { 0 } = 0$,$\sqrt { 1 0 0 \times 0 } = 0 .$·
试一试,参考上面的结果,比较各组等式的大小关系.
生:上面各组中两个算式的结果相等.
## 二、新课教授
$\because x > 0, \therefore x = 30 \sqrt{2}.$
【例4】若$\frac { \sqrt { x + 1 } } { \sqrt { x - 1 } } = \sqrt { \frac { x + 1 } { x - 1 } }$成立,求x的取值范围.
分析:等式$\frac { \sqrt { a } } { \sqrt { b } } = \sqrt { \frac { a } { b } }$只有a≥0,b&gt;0时才能成立.
解:由题意,得$\{ \begin{matrix} x + 1 \geq 0 \\ x - 1 > 0 \end{matrix} ,$即$\{ \begin{matrix} x \geq - 1 , \\ x > 1 . \end{matrix}$
∴x&gt;1.
## 四、巩固练习
(2)首先利用$\sqrt { a ^ { 3 } } = \vert a \vert$化简去掉二次根号,再根据x的范围来判断绝对值中的代数式的正负,去掉绝对值符号.
$\vert x - 2 \vert + \sqrt { ( x + 3 ) ^ { 2 } } + \sqrt { x ^ { 2 } - 10 x + 25 }$
=|x-2|+|x+3|+|x-5|·
∵-3≤x≤2,
∴x-2≤0,$\textcircled { 3 }$x+3≥0,,x-5<0.
∴原式=-(x-2)+(x+3)-(x-5)
=-x+2+x+3-x+5
"""
cleaned_text = Scan_Content_Aggregation(original_text)
print(cleaned_text)
import base64
from PIL import Image
from io import BytesIO
import cv2
import numpy as np
import matplotlib.pyplot as plt
def get_base64_from_image(image):
"""
将输入的图像转换为Base64编码的字符串。
:param image: 输入图像(OpenCV格式)
:return: Base64编码的字符串
"""
# 将图像转换为PIL格式
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
# 使用BytesIO将图像保存到内存中的字节流
buffered = BytesIO()
pil_image.save(buffered, format="PNG")
# 获取字节流并进行Base64编码
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
return img_str
def get_subimage_base64(image_path, coordinates, scale_factor=1):
"""
提取子图并返回其Base64编码的字符串。
:param image_path: 输入图像路径(可以是图像文件或PDF文件)
:param coordinates: 以8个数字表示的四个角坐标,顺序为左上,右上,右下,左下
:param scale_factor: 如果输入为PDF,scale_factor 用于将坐标从72dpi转换为目标图像的像素坐标(默认为1,无需转换)
:return: Base64编码的子图
"""
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"Failed to read the image file '{image_path}'.")
# 获取图像的尺寸
height, width = image.shape[:2]
# 解析输入坐标
x1, y1, x2, y2, x3, y3, x4, y4 = map(int, coordinates)
# 计算裁剪区域的矩形边界
x_min = min(x1, x2, x3, x4)
y_min = min(y1, y2, y3, y4)
x_max = max(x1, x2, x3, x4)
y_max = max(y1, y2, y3, y4)
# 确保裁剪区域不超出图像边界
x_min = max(0, x_min)
y_min = max(0, y_min)
x_max = min(width, x_max)
y_max = min(height, y_max)
# 提取ROI
roi = image[y_min:y_max, x_min:x_max]
# 如果子图为空,抛出异常
if roi.size == 0:
raise ValueError(f"The extracted region is empty for the given coordinates: {coordinates}")
# 将提取的子图转换为Base64编码并返回
return get_base64_from_image(roi)
def visualize_base64_image(base64_str, output_path=None):
"""
可视化Base64编码的图像,并使用cv2存储图像。
:param base64_str: Base64编码的图像字符串
:param output_path: 可选的输出路径,如果提供则使用cv2保存图像到文件
"""
# 解码Base64编码为图像字节流
img_data = base64.b64decode(base64_str)
# 使用BytesIO将字节流读取为PIL图像
img = Image.open(BytesIO(img_data))
# 将PIL图像转换为OpenCV格式(BGR)
open_cv_image = np.array(img)
open_cv_image = open_cv_image[:, :, ::-1] # 从RGB转为BGR
# 使用cv2保存图像,如果指定了输出路径
if output_path:
cv2.imwrite(output_path, open_cv_image)
print(f"Image saved to: {output_path}")
if __name__ == '__main__':
# 示例:图像坐标为 (left, upper, right, lower)
image_path = "/data/wangtengbo/Deployments_Formula_Checker_v5_复现并微调_v7.0-合和OCR-没有版面/app/logs_data/sub_images/20241210142852866/formula_1.png"
coordinates = [74, 143, 1240, 143, 1240, 241, 74, 241]
# 获取子图的Base64编码
base64_subimage = get_subimage_base64(image_path, coordinates)
# 可视化Base64编码的子图
visualize_base64_image(base64_subimage,output_path='/data/wangtengbo/Deployments_Formula_Checker_v5_复现并微调_v7.0-合和OCR-没有版面/app/utils/base64.jpg')
import base64
from PIL import Image
from io import BytesIO
import cv2
import numpy as np
import cv2
import numpy as np
import os
from loguru import logger
def draw_box_and_save(image_path, coordinates, draw_box_sub_img_save_dir,save_img_name,scale_factor=1):
"""
在图像上绘制多边形框并保存回原路径,返回图像路径。
:param image_path: 输入图像路径
:param coordinates: 以8个数字表示的四个角坐标,顺序为左上、右上、右下、左下
:param scale_factor: 可选的缩放因子(若不需要则默认为1)
:return: 保存后的图像路径(即 image_path)
"""
# 读取图像
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"无法读取图像文件: {image_path}")
# 解析坐标并应用缩放
x1, y1, x2, y2, x3, y3, x4, y4 = map(float, coordinates)
x1, y1, x2, y2, x3, y3, x4, y4 = [int(c * scale_factor) for c in (x1, y1, x2, y2, x3, y3, x4, y4)]
# 构造多边形顶点数组
pts = np.array([[x1, y1],
[x2, y2],
[x3, y3],
[x4, y4]], dtype=np.int32).reshape((-1, 1, 2))
# 在图像上绘制红色多边形框
cv2.polylines(image, [pts], isClosed=True, color=(0, 0, 255), thickness=8)
save_image_path=os.path.join(draw_box_sub_img_save_dir,save_img_name+'.png')
# 保存覆盖原图
success = cv2.imwrite(save_image_path, image)
if not success:
logger.info(f'保存图像错误!\n\nsave_image_path={save_image_path}')
raise IOError(f"无法将图像保存到: {save_image_path}")
return save_image_path
def get_base64_from_image(image):
"""
将输入的图像转换为Base64编码的字符串。
:param image: 输入图像(OpenCV格式)
:return: Base64编码的字符串
"""
# 将图像转换为PIL格式
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
# 使用BytesIO将图像保存到内存中的字节流
buffered = BytesIO()
pil_image.save(buffered, format="PNG")
# 获取字节流并进行Base64编码
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
return img_str
def get_subimage_base64_boxes(image_path, coordinates, scale_factor=1):
"""
提取子图并返回其Base64编码的字符串。
:param image_path: 输入图像路径(可以是图像文件或PDF文件)
:param coordinates: 以8个数字表示的四个角坐标,顺序为左上,右上,右下,左下
:param scale_factor: 如果输入为PDF,scale_factor 用于将坐标从72dpi转换为目标图像的像素坐标(默认为1,无需转换)
:return: Base64编码的子图
"""
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"Failed to read the image file '{image_path}'.")
# 获取图像的尺寸
height, width = image.shape[:2]
# 解析输入坐标
x1, y1, x2, y2, x3, y3, x4, y4 = map(int, coordinates)
# 计算矩形边界的顶点
pts = np.array([[x1, y1], [x2, y2], [x3, y3], [x4, y4]], np.int32)
pts = pts.reshape((-1, 1, 2))
# 在图像上绘制红色边框
image_with_box = image.copy()
cv2.polylines(image_with_box, [pts], isClosed=True, color=(0, 0, 255), thickness=8)
# 将修改后的图像转换为Base64编码并返回
return get_base64_from_image(image_with_box)
def visualize_base64_image(base64_str, output_path=None):
"""
可视化Base64编码的图像,并使用cv2存储图像。
:param base64_str: Base64编码的图像字符串
:param output_path: 可选的输出路径,如果提供则使用cv2保存图像到文件
"""
# 解码Base64编码为图像字节流
img_data = base64.b64decode(base64_str)
# 使用BytesIO将字节流读取为PIL图像
img = Image.open(BytesIO(img_data))
# 将PIL图像转换为OpenCV格式(BGR)
open_cv_image = np.array(img)
open_cv_image = open_cv_image[:, :, ::-1] # 从RGB转为BGR
# 使用cv2保存图像,如果指定了输出路径
if output_path:
cv2.imwrite(output_path, open_cv_image)
print(f"Image saved to: {output_path}")
if __name__ == '__main__':
# 示例:图像坐标为 (left, upper, right, lower)
image_path = "/data/wangtengbo/Deployments_Formula_Checker_v5_复现并微调_v7.0-合和OCR-没有版面/app/logs_data/sub_images/20241214164802726/formula_1.png"
coordinates = [55, 1476, 3659, 1510, 3645, 3070, 41, 3042]
# 获取添加红色边界框的图像的Base64编码
base64_image_with_box = get_subimage_base64_boxes(image_path, coordinates)
# 可视化Base64编码的图像
visualize_base64_image(base64_image_with_box, output_path='/data/wangtengbo/Deployments_Formula_Checker_v5_复现并微调_v7.0-合和OCR-没有版面/app/utils/image_with_box.jpg')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment