# encoding=utf8
import sys
import os
import time
from loguru import logger
import datetime
from sqlalchemy import create_engine, UniqueConstraint
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy import Column, String, Integer, DateTime, Text, SmallInteger
from config.config import MYSQL_DB_URL

# Connection-pool settings for the knowledge-base database engine.
# The SQLAlchemy default for each option is noted alongside it.
_KNOWLEDGE_POOL_OPTIONS = {
    # Number of connections kept in the pool (SQLAlchemy default: 5).
    "pool_size": 50,
    # Extra connections that may be opened beyond pool_size under load.
    # With these values at most 50 + 10 = 60 connections exist at once;
    # further checkouts wait for a free one (SQLAlchemy default: 10).
    "max_overflow": 10,
    # Seconds a checkout waits for a free connection before raising a
    # timeout error (SQLAlchemy default: 30).
    "pool_timeout": 30,
    # Connections older than this many seconds are replaced instead of
    # reused; typically set below the server's idle timeout (e.g. MySQL
    # wait_timeout) so stale connections are not handed out
    # (SQLAlchemy default: -1, i.e. never recycle).
    "pool_recycle": 3600,
}

engine_knowledge = create_engine(url=MYSQL_DB_URL, **_KNOWLEDGE_POOL_OPTIONS)
# Session factory bound to engine_knowledge; call it to get a new Session.
Session_Knowledge = sessionmaker(bind=engine_knowledge)

# Declarative base shared by the knowledge-base ORM models.
BaseKnowledge = declarative_base()

class KnowledgeBasePrompts(BaseKnowledge):
    """ORM model for the ``knowledge-base-prompts`` table.

    ``task_id`` is declared unique. The meaning of ``status`` values and the
    format stored in ``attr`` are not defined in this module -- TODO document.
    """
    __tablename__ = 'knowledge-base-prompts'
    id = Column(Integer, primary_key=True)
    task_id = Column(Integer, unique=True)
    task_name = Column(String(50))
    prompt = Column(Text)
    llm = Column(String(30))
    author = Column(String(50))
    status = Column(Integer)
    attr = Column(Text)
    create_time = Column(DateTime, default=datetime.datetime.now)
    # ``onupdate`` refreshes the timestamp whenever SQLAlchemy emits an UPDATE
    # for this row; with only ``default`` it would stay frozen at insert time.
    # Raw SQL updates that bypass SQLAlchemy still will not touch it.
    update_time = Column(DateTime, default=datetime.datetime.now,
                         onupdate=datetime.datetime.now)

class CorrectLLMStatisticsLog(BaseKnowledge):
    """ORM model for the ``correct_llm_statistics_log`` table.

    Rows are inserted by ``DBUtils.insert_llm_log``. Judging by the columns
    (model, url, api_version, token counts, start/end times), each row
    appears to record one LLM request -- confirm against callers.

    NOTE(review): ``api_key`` is stored verbatim in plain text; consider
    masking or omitting it before insert.
    """
    __tablename__ = 'correct_llm_statistics_log'
    id = Column(Integer, primary_key=True)
    env = Column(String(32))
    source_type = Column(String(32))
    server_name = Column(String(64))
    user_id = Column(String(64))
    publish_id = Column(String(64))
    # NOTE(review): values longer than a declared String length are rejected
    # or truncated depending on the database's SQL mode (presumably MySQL);
    # confirm callers keep request/message within 512 characters.
    request = Column(String(512))
    model = Column(String(512))
    url = Column(String(512))
    api_key = Column(String(512))  # plain-text secret -- see class docstring
    api_version = Column(String(512))
    input_token = Column(Integer)
    output_token = Column(Integer)
    start_time = Column(DateTime)
    end_time = Column(DateTime)
    message = Column(String(512))
    doc_id = Column(String(64))
    fragment_id = Column(String(512))
    book_name = Column(String(512))
    text = Column(Text)
    backup = Column(Text)
    # Only an insert-time default (no onupdate), so this effectively records
    # when the row was written.
    update_time = Column(DateTime, default=datetime.datetime.now)

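"""
DBUtils: static helpers that look up rows in knowledge-base-prompts and
insert rows into correct_llm_statistics_log, each using a short-lived
session from Session_Knowledge.
"""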

class DBUtils(object):
    """Static helpers around the knowledge-base database.

    Every helper opens its own session from ``Session_Knowledge`` and closes
    it before returning, so no instance state is needed.
    """

    @staticmethod
    def get_prompt(task_id, task_name=None):
        """Return the ``KnowledgeBasePrompts`` row for ``task_id``, or None.

        Args:
            task_id: value matched against ``KnowledgeBasePrompts.task_id``.
            task_name: accepted for backward compatibility; it is currently
                NOT used in the lookup.

        Returns:
            The first matching row as a detached ORM object (the session is
            closed on return; its already-loaded columns remain readable),
            or None when nothing matches or the query fails. Failures are
            logged with a traceback rather than raised.
        """
        try:
            # The session context manager closes the session on exit, even
            # when the query raises.
            with Session_Knowledge() as session:
                return (
                    session.query(KnowledgeBasePrompts)
                    .filter_by(task_id=task_id)
                    .first()
                )
        except Exception as e:
            logger.exception("get_prompt error: {0}", e)
            return None

    @staticmethod
    def insert_llm_log(env=None, source_type=None, server_name=None, user_id=None, publish_id=None,
                       request=None, model=None, url=None, api_key=None, api_version=None,
                       input_token=0, output_token=0, start_time=None, end_time=None,
                       message=None, doc_id=None, fragment_id=None, book_name=None, text=None,
                       backup=None):
        """Insert one row into ``correct_llm_statistics_log``.

        Every keyword argument is stored in the ``CorrectLLMStatisticsLog``
        column of the same name. The write is best-effort: any failure,
        including a failed rollback, is logged with a traceback and
        swallowed, so callers are never interrupted. The elapsed time is
        logged whether or not the insert succeeded.

        NOTE(review): ``api_key`` is persisted verbatim in plain text;
        consider masking it before it reaches this function.
        """
        # perf_counter is monotonic, so the measured duration is not skewed
        # by wall-clock adjustments the way datetime.now() can be.
        started = time.perf_counter()
        try:
            log = CorrectLLMStatisticsLog(env=env, source_type=source_type, server_name=server_name,
                                          user_id=user_id, publish_id=publish_id, request=request,
                                          model=model, url=url, api_key=api_key, api_version=api_version,
                                          input_token=input_token, output_token=output_token,
                                          start_time=start_time, end_time=end_time, message=message,
                                          doc_id=doc_id, fragment_id=fragment_id, book_name=book_name,
                                          text=text, backup=backup)
            # session.begin() commits on normal exit and rolls back on error;
            # the outer context manager then closes the session. Errors raised
            # by the rollback itself are also caught by the except below.
            with Session_Knowledge() as session, session.begin():
                session.add(log)
        except Exception as e:
            logger.exception("insert_llm_log error: {}", e)
        # Render as a timedelta so the log line keeps its H:MM:SS.ffffff format.
        elapsed = datetime.timedelta(seconds=time.perf_counter() - started)
        logger.info("insert_llm_log elapsed time: {}", elapsed)
    
def test_dbutil():
    """Placeholder for manual DBUtils checks; it currently does nothing."""


if __name__ == "__main__":
    test_dbutil()