import json
import os
import logging
import re
import time
import threading
import numpy as np
from typing import List, Dict, Any, Optional
from chroma_client import chroma_client
from embedding_service import embedding_service
# Optional dependency: project-local reranker client. Only the availability
# flag is set here; a missing module is silently tolerated.
try:
    from reranker_client import reranker_client
    RERANKER_SERVICE_AVAILABLE = True
except ImportError:
    RERANKER_SERVICE_AVAILABLE = False
    pass
# Optional dependency: rank_bm25 provides the BM25Okapi scorer used by
# BM25Index. When it is missing, BM25_AVAILABLE is False and a warning is logged.
try:
    from rank_bm25 import BM25Okapi
    BM25_AVAILABLE = True
except ImportError:
    BM25_AVAILABLE = False
    logger = logging.getLogger(__name__)
    logger.warning("rank_bm25 не установлен. Установите: pip install rank-bm25")
# Optional dependency: sentence-transformers CrossEncoder used for reranking.
try:
    from sentence_transformers import CrossEncoder
    CROSSENCODER_AVAILABLE = True
except ImportError:
    CROSSENCODER_AVAILABLE = False
    logger = logging.getLogger(__name__)
    logger.warning("sentence-transformers не установлен. Установите: pip install sentence-transformers")
# Optional dependency: ONNX runtime stack (optimum + transformers +
# onnxruntime). All three imports must succeed for ONNX_AVAILABLE to be True.
try:
    from optimum.onnxruntime import ORTModelForSequenceClassification
    from transformers import AutoTokenizer
    import onnxruntime
    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False
    logger = logging.getLogger(__name__)
    logger.warning("ONNX не доступен, будет использован обычный CrossEncoder")
# Module-level logger; the `logger` bindings inside the except branches above
# exist only so the warnings can be emitted before this line runs.
logger = logging.getLogger(__name__)
class BM25Index:
    """In-memory BM25 keyword index over a list of document dicts.

    Each document is a dict with ``content``, ``metadata`` and ``id`` keys.
    ``documents``, ``document_metadata`` and ``document_ids`` are parallel
    lists kept in the same order as the BM25 corpus, so the score at
    position ``i`` belongs to ``documents[i]``.
    """

    # Comparison operators accepted inside an operator expression such as
    # ``{'size': {'$gte': 10}}``. Values are plain functions (actual, operand).
    _COMPARATORS = {
        '$eq': lambda actual, operand: actual == operand,
        '$ne': lambda actual, operand: actual != operand,
        '$gt': lambda actual, operand: actual > operand,
        '$gte': lambda actual, operand: actual >= operand,
        '$lt': lambda actual, operand: actual < operand,
        '$lte': lambda actual, operand: actual <= operand,
        '$in': lambda actual, operand: actual in operand,
        '$nin': lambda actual, operand: actual not in operand,
    }

    def __init__(self):
        """Create an empty index that is flagged as needing a build."""
        self.bm25 = None
        self.documents = []
        self.document_metadata = []
        self.document_ids = []
        self._needs_rebuild = True
        self._tokenized_docs = []

    def _tokenize(self, text: str) -> List[str]:
        """Lower-case ``text`` and return its word tokens of length >= 2."""
        return re.findall(r'\b\w{2,}\b', text.lower())

    def build_index(self, documents: List[Dict[str, Any]]):
        """Replace the index content with ``documents``.

        Documents whose ``content`` is empty or whitespace-only are skipped.
        Does nothing (besides a warning) when rank_bm25 is not installed.
        On any error the index is dropped and flagged for rebuild.
        """
        if not BM25_AVAILABLE:
            logger.warning("BM25 недоступен, индекс не создан")
            return
        try:
            contents, metadata, ids, tokenized = [], [], [], []
            for doc in documents:
                content = doc.get('content', '')
                if content and content.strip():
                    contents.append(content)
                    # A missing or None metadata entry is stored as {} so
                    # filtering never has to deal with None.
                    metadata.append(doc.get('metadata') or {})
                    ids.append(doc.get('id', ''))
                    tokenized.append(self._tokenize(content))
            self.documents = contents
            self.document_metadata = metadata
            self.document_ids = ids
            self._tokenized_docs = tokenized
            if tokenized:
                self.bm25 = BM25Okapi(tokenized)
                self._needs_rebuild = False
                logger.info(f"BM25 индекс создан из {len(self.documents)} документов")
            else:
                # Do not keep a scorer built from a previous corpus around
                # once the document lists have been emptied.
                self.bm25 = None
                self._needs_rebuild = True
                logger.warning("Нет документов для создания BM25 индекса")
        except Exception as e:
            logger.error(f"Ошибка создания BM25 индекса: {e}")
            self.bm25 = None
            self._needs_rebuild = True

    def search(self, query: str, top_k: int = 10, where_clause: Optional[Dict] = None) -> List[Dict[str, Any]]:
        """Return up to ``top_k`` documents with a positive BM25 score.

        Args:
            query: Free-text query; tokenized the same way as the corpus.
            top_k: Maximum number of results.
            where_clause: Optional metadata filter (see ``_matches_filter``).

        Returns:
            Dicts with ``content``, ``metadata``, ``id``, ``bm25_score`` and
            ``distance`` (``1 / (1 + score)``), best score first. An empty
            list when the index is unavailable, empty, or an error occurs.
        """
        if not BM25_AVAILABLE or not self.bm25:
            return []
        try:
            if not self.documents or not self.document_metadata or not self.document_ids:
                logger.warning("BM25 индекс пуст, поиск невозможен")
                return []
            query_tokens = self._tokenize(query)
            if not query_tokens:
                return []
            scores = self.bm25.get_scores(query_tokens)
            if len(scores) != len(self.documents):
                logger.warning(f"Несоответствие длин: scores={len(scores)}, documents={len(self.documents)}")
                return []
            results = []
            for i, score in enumerate(scores):
                if i >= len(self.document_metadata) or i >= len(self.document_ids):
                    logger.warning(f"Индекс {i} выходит за границы списков")
                    continue
                if not score > 0:
                    continue
                metadata = self.document_metadata[i]
                # Filter before building the result dict to avoid wasted work.
                if where_clause and not self._matches_filter(metadata, where_clause):
                    continue
                results.append({
                    'content': self.documents[i],
                    'metadata': metadata,
                    'id': self.document_ids[i],
                    'bm25_score': float(score),
                    'distance': 1.0 / (1.0 + score)
                })
            results.sort(key=lambda x: x['bm25_score'], reverse=True)
            return results[:top_k]
        except Exception as e:
            logger.exception(f"Ошибка BM25 поиска: {e}")
            return []

    def _matches_filter(self, metadata: Dict, where_clause: Dict) -> bool:
        """Check ``metadata`` against a where-style filter.

        Supported forms (all top-level entries must hold):
          * ``{'key': value}`` - equality;
          * ``{'key': {'$gte': a, '$lte': b}}`` - operator expression, every
            operator must hold (see ``_COMPARATORS`` for the operator set);
          * ``{'$and': [f1, f2]}`` / ``{'$or': [f1, f2]}`` - nested filters.

        A key missing from ``metadata`` never matches. ``None`` metadata is
        treated as an empty dict.
        """
        metadata = metadata or {}
        for key, expected in where_clause.items():
            if key == '$and':
                if not all(self._matches_filter(metadata, sub) for sub in expected):
                    return False
            elif key == '$or':
                if not any(self._matches_filter(metadata, sub) for sub in expected):
                    return False
            elif key not in metadata:
                return False
            elif not self._value_matches(metadata[key], expected):
                return False
        return True

    @staticmethod
    def _value_matches(actual: Any, expected: Any) -> bool:
        """Compare one metadata value with a plain value or operator expression."""
        is_operator_expr = (
            isinstance(expected, dict)
            and bool(expected)
            and all(isinstance(op, str) and op.startswith('$') for op in expected)
        )
        if not is_operator_expr:
            return actual == expected
        for op, operand in expected.items():
            compare = BM25Index._COMPARATORS.get(op)
            if compare is None:
                # Unknown operator: fail closed rather than silently match.
                return False
            try:
                if not compare(actual, operand):
                    return False
            except TypeError:
                # Incomparable types (e.g. str vs int) count as a non-match.
                return False
        return True

    def update_index(self, new_documents: List[Dict[str, Any]]):
        """Append ``new_documents`` to the indexed ones and rebuild the index."""
        combined_docs = [
            {'content': content, 'metadata': meta, 'id': doc_id}
            for content, meta, doc_id in zip(self.documents, self.document_metadata, self.document_ids)
        ]
        for doc in new_documents:
            content = doc.get('content', '')
            if content and content.strip():
                combined_docs.append({
                    'content': content,
                    'metadata': doc.get('metadata') or {},
                    'id': doc.get('id', ''),
                })
        self.build_index(combined_docs)
class ProperRetriever:

    def __init__(self, settings_file='retriever_settings.json'):
        """Load settings and start with an empty BM25 index and no models.

        Args:
            settings_file: Path of the JSON file holding retriever settings.
        """
        # The path has to be stored first: _load_settings() reads it.
        self.settings_file = settings_file
        self.settings = self._load_settings()
        self._load_preprocessing_settings()

        # Reranker handles start empty; the CrossEncoder and ONNX ones are
        # populated on first use by the rerank methods.
        self._reranker_model = None
        self._crossencoder_model = None
        self._onnx_model = None
        self._onnx_tokenizer = None
        self._onnx_tried = False

        # Keyword index is created empty and filled by _ensure_bm25_index().
        self.bm25_index = BM25Index()
        self._bm25_index_loaded = False
    def _load_preprocessing_settings(self):

        try:
            import os
            parser_settings_file = 'parser_settings.json'
            if os.path.exists(parser_settings_file):
                import json
                with open(parser_settings_file, 'r', encoding='utf-8') as f:
                    parser_settings = json.load(f)
                    self.remove_extra_whitespace = parser_settings.get('remove_extra_whitespace', True)
                    self.normalize_unicode = parser_settings.get('normalize_unicode', True)
                    self.remove_special_chars = parser_settings.get('remove_special_chars', False)
                    self.lowercase_text = parser_settings.get('lowercase_text', False)
                    logger.debug(f"Загружены настройки предобработки: remove_extra_whitespace={self.remove_extra_whitespace}, normalize_unicode={self.normalize_unicode}, remove_special_chars={self.remove_special_chars}, lowercase_text={self.lowercase_text}")
            else:
                self.remove_extra_whitespace = True
                self.normalize_unicode = True
                self.remove_special_chars = False
                self.lowercase_text = False
        except Exception as e:
            logger.warning(f"Ошибка загрузки настроек предобработки: {e}, используются значения по умолчанию")
            self.remove_extra_whitespace = True
            self.normalize_unicode = True
            self.remove_special_chars = False
            self.lowercase_text = False
    def _preprocess_query(self, query: str) -> str:

        if not query:
            return query
        self._load_preprocessing_settings()
        if self.remove_extra_whitespace:
            import re
            query = re.sub(r'[ \t]+', ' ', query)
            query = re.sub(r'\n{3,}', '\n\n', query)
            query = query.strip()
        if self.normalize_unicode:
            import unicodedata
            query = unicodedata.normalize('NFC', query)
        if self.remove_special_chars:
            import re
            query = re.sub(r'[^\w\s\n\r.,;:!?()\-—–«»""''`]', '', query, flags=re.UNICODE)
        if self.lowercase_text:
            query = query.lower()
        return query
    def _load_settings(self) -> Dict[str, Any]:
        """Return retriever settings from ``self.settings_file`` or defaults.

        The file content is used as-is when it exists and parses to a JSON
        object. A missing file, a read/parse error, or a JSON value that is
        not an object (list, string, number) yields the built-in defaults;
        callers rely on ``settings.get(...)``, so a non-dict must never be
        returned.
        """
        try:
            if os.path.exists(self.settings_file):
                with open(self.settings_file, 'r', encoding='utf-8') as f:
                    settings = json.load(f)
                if isinstance(settings, dict):
                    logger.info(f"Настройки ретривера загружены из {self.settings_file}")
                    return settings
                logger.warning(
                    f"Не удалось загрузить настройки из {self.settings_file}: "
                    f"ожидался JSON-объект, получен {type(settings).__name__}"
                )
        except Exception as e:
            logger.warning(f"Не удалось загрузить настройки из {self.settings_file}: {e}")
        default_settings = {
            'search_k': 100,
            'similarity_threshold': 0.001,
            'search_type': 'similarity',
            'mmr_lambda': 0.7,
            'rerank_results': 'simple',
            'max_results': 10,
            'filter_by_category': False,
            'filter_by_date': False,
            'filter_by_size': False
        }
        logger.info(f"Используются настройки по умолчанию: {default_settings}")
        return default_settings
    def _reload_settings(self):

        self.settings = self._load_settings()
    def search(self, query: str, **kwargs) -> List[Dict[str, Any]]:
        """Run the full retrieval pipeline for ``query``.

        Settings are reloaded on every call. Each of ``search_k``,
        ``search_type``, ``similarity_threshold``, ``max_results`` and
        ``rerank_results`` may be overridden through ``kwargs``; otherwise the
        loaded setting, then a hard-coded fallback, is used.

        Pipeline: preprocess + embed the query, search (``mmr``, ``hybrid`` or
        plain similarity), filter by similarity threshold, rerank unless
        ``rerank_results == 'none'``, then truncate to ``max_results``.

        Returns:
            The resulting document dicts; an empty list when the query
            embedding cannot be created.
        """
        self._reload_settings()
        stored = self.settings

        def option(name, fallback):
            # Explicit keyword argument wins over the stored setting.
            return kwargs.get(name, stored.get(name, fallback))

        search_k = option('search_k', 100)
        search_type = option('search_type', 'similarity')
        similarity_threshold = option('similarity_threshold', 0.2)
        max_results = option('max_results', 10)
        rerank_results = option('rerank_results', 'simple')
        logger.info(f"Поиск: query='{query[:50]}...', type={search_type}, k={search_k}, threshold={similarity_threshold}, max={max_results}, rerank={rerank_results}")

        try:
            query_embedding = embedding_service.create_embedding(self._preprocess_query(query))
        except Exception as e:
            logger.error(f"Ошибка создания эмбеддинга: {e}")
            return []

        where_clause = self._build_where_clause()
        if where_clause:
            logger.debug(f"Применяются фильтры: {where_clause}")

        if search_type == 'hybrid':
            # Hybrid search receives the raw query text for its keyword part.
            documents = self._hybrid_search(query, query_embedding, search_k, where_clause)
        elif search_type == 'mmr':
            documents = self._mmr_search(query_embedding, search_k, where_clause)
        else:
            documents = self._similarity_search(query_embedding, search_k, where_clause)
        logger.info(f"Найдено {len(documents)} документов после поиска")

        documents = self._apply_similarity_threshold(documents, similarity_threshold)
        logger.info(f"Осталось {len(documents)} документов после фильтрации по threshold={similarity_threshold}")

        if rerank_results != 'none':
            reranker = self._advanced_rerank if rerank_results == 'advanced' else self._simple_rerank
            documents = reranker(documents, query)
            logger.info(f"После реранкинга: {len(documents)} документов")

        documents = documents[:max_results]
        logger.info(f"Финальный результат: {len(documents)} документов")
        return documents
    def _similarity_search(self, query_embedding, n_results, where_clause):
        """Query the Chroma collection by embedding and format the hits.

        Returns:
            The output of ``_format_results``; an empty list on any error.
        """
        query_args = {
            'query_embeddings': [query_embedding],
            'n_results': n_results,
            'where': where_clause,
        }
        try:
            raw_hits = chroma_client.get_collection().query(**query_args)
            formatted = self._format_results(raw_hits)
        except Exception as e:
            logger.error(f"Ошибка similarity search: {e}")
            formatted = []
        return formatted
    def _mmr_search(self, query_embedding, n_results, where_clause):
        """Select up to ``n_results`` documents by Maximal Marginal Relevance.

        Candidates (``min(n_results * 2, 200)`` of them) come from a vector
        query. The top candidate is always selected first; each further pick
        maximizes ``lambda * relevance - (1 - lambda) * max_similarity`` where
        relevance is ``1 - distance`` and max_similarity is the highest cosine
        similarity to an already selected document. ``lambda`` is the
        ``mmr_lambda`` setting (default 0.7).

        Every candidate is embedded exactly once, and its similarity to the
        selected set is updated incrementally after each pick, instead of
        re-embedding all remaining candidates on every round.

        Returns:
            Selected documents in pick order. If the first candidate cannot be
            embedded, the plain top ``n_results`` candidates. On any other
            error, the result of ``_similarity_search``.
        """
        mmr_lambda = self.settings.get('mmr_lambda', 0.7)
        try:
            collection = chroma_client.get_collection()
            candidates = collection.query(
                query_embeddings=[query_embedding],
                n_results=min(n_results * 2, 200),
                where=where_clause
            )
            documents = self._format_results(candidates)
            if not documents:
                return []

            try:
                seed_embedding = embedding_service.create_embedding(documents[0]['content'])
            except Exception as e:
                logger.warning(f"Ошибка создания эмбеддинга для MMR: {e}")
                return documents[:n_results]
            selected = [documents[0]]
            if len(selected) >= n_results:
                return selected

            # Pool entries: [doc, embedding, relevance, max similarity to the
            # selected set]. A candidate that cannot be processed is dropped,
            # which matches it never being pickable.
            pool = []
            for doc in documents[1:]:
                try:
                    embedding = embedding_service.create_embedding(doc['content'])
                    relevance = 1 - doc['distance']
                    redundancy = self._cosine_similarity(embedding, seed_embedding)
                except Exception as e:
                    logger.warning(f"Ошибка обработки документа в MMR: {e}")
                    continue
                pool.append([doc, embedding, relevance, redundancy])

            while len(selected) < n_results and pool:
                best_idx = -1
                best_mmr_score = float('-inf')
                for i, (_, _, relevance, redundancy) in enumerate(pool):
                    mmr_score = mmr_lambda * relevance - (1 - mmr_lambda) * redundancy
                    # Strict comparison keeps the earliest candidate on ties.
                    if mmr_score > best_mmr_score:
                        best_mmr_score = mmr_score
                        best_idx = i
                if best_idx < 0:
                    break
                best_doc, best_embedding, _, _ = pool.pop(best_idx)
                selected.append(best_doc)
                if len(selected) >= n_results:
                    break

                # Only the newly selected document can raise a candidate's
                # max similarity, so one comparison per candidate is enough.
                survivors = []
                for entry in pool:
                    try:
                        similarity = self._cosine_similarity(entry[1], best_embedding)
                    except Exception as e:
                        logger.warning(f"Ошибка обработки документа в MMR: {e}")
                        continue
                    entry[3] = max(entry[3], similarity)
                    survivors.append(entry)
                pool = survivors
            return selected
        except Exception as e:
            logger.error(f"Ошибка MMR search: {e}")
            return self._similarity_search(query_embedding, n_results, where_clause)
    def _ensure_bm25_index(self):

        if self._bm25_index_loaded and not self.bm25_index._needs_rebuild:
            return
        try:
            documents = self._load_all_documents_for_bm25()
            if documents:
                self.bm25_index.build_index(documents)
                self._bm25_index_loaded = True
                logger.info(f"BM25 индекс загружен: {len(documents)} документов")
            else:
                logger.warning("Нет документов для BM25 индекса")
        except Exception as e:
            logger.error(f"Ошибка загрузки BM25 индекса: {e}")
    def _load_all_documents_for_bm25(self) -> List[Dict[str, Any]]:
        """Fetch every document of the collection as BM25 input dicts.

        Returns:
            A list of ``{'content', 'metadata', 'id'}`` dicts. Metadata that
            is missing or ``None`` becomes ``{}``; a missing id becomes
            ``doc_<position>``. An empty list on any error.
        """
        try:
            collection = chroma_client.get_collection()
            result = collection.get()
            contents = result.get('documents') or []
            metadatas = result.get('metadatas') or []
            ids = result.get('ids') or []
            documents = []
            for i, content in enumerate(contents):
                # Guard the parallel lists: a short or None-holding list must
                # not discard the whole corpus or leak None into filtering.
                metadata = metadatas[i] if i < len(metadatas) else None
                documents.append({
                    'content': content,
                    'metadata': metadata or {},
                    'id': ids[i] if i < len(ids) else f'doc_{i}'
                })
            return documents
        except Exception as e:
            logger.error(f"Ошибка загрузки документов для BM25: {e}")
            return []
    def _hybrid_search(self, query, query_embedding, n_results, where_clause):
        """Combine vector and BM25 keyword search results.

        Both searches fetch ``max(n_results * 5, 100)`` candidates. When both
        return results they are merged with ``_reciprocal_rank_fusion(k=60)``;
        when only one does, its list is used as-is. At most ``n_results * 3``
        documents are returned. On any error the plain similarity search
        result is returned instead.
        """
        try:
            pool_size = max(n_results * 5, 100)
            vector_docs = self._similarity_search(query_embedding, pool_size, where_clause)
            logger.debug(f"Векторный поиск вернул {len(vector_docs)} документов")
            bm25_docs = self._keyword_search(query, pool_size, where_clause)
            logger.debug(f"BM25 поиск вернул {len(bm25_docs)} документов")

            if not vector_docs and not bm25_docs:
                combined = []
                logger.warning("Оба поиска не вернули результатов")
            elif not bm25_docs:
                combined = vector_docs
                logger.debug("Используем только векторный поиск (BM25 не вернул результатов)")
            elif not vector_docs:
                combined = bm25_docs
                logger.debug("Используем только BM25 поиск (векторный не вернул результатов)")
            else:
                combined = self._reciprocal_rank_fusion(vector_docs, bm25_docs, k=60)
                logger.debug(f"RRF объединил в {len(combined)} документов")

            result_count = min(len(combined), n_results * 3)
            logger.debug(f"Hybrid search возвращает {result_count} документов (из {len(combined)} после RRF, n_results={n_results})")
            return combined[:result_count]
        except Exception as e:
            logger.error(f"Ошибка hybrid search: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return self._similarity_search(query_embedding, n_results, where_clause)
    def _keyword_search(self, query, n_results, where_clause):
        """Search by keywords through the BM25 index.

        Falls back to ``_keyword_search_fallback`` when rank_bm25 is not
        installed or the BM25 search raises.

        Returns:
            Dicts with ``content``, ``metadata``, ``distance`` and
            ``bm25_score`` (the index ``id`` is not carried over).
        """
        if not BM25_AVAILABLE:
            logger.warning("BM25 недоступен, используем fallback")
            return self._keyword_search_fallback(query, n_results, where_clause)
        try:
            self._ensure_bm25_index()
            hits = self.bm25_index.search(query, top_k=n_results, where_clause=where_clause)
            return [
                {
                    'content': hit['content'],
                    'metadata': hit['metadata'],
                    'distance': hit['distance'],
                    'bm25_score': hit.get('bm25_score', 0.0)
                }
                for hit in hits
            ]
        except Exception as e:
            logger.error(f"Ошибка BM25 поиска: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return self._keyword_search_fallback(query, n_results, where_clause)
    def _keyword_search_fallback(self, query, n_results, where_clause):
        """Vector search with a keyword-overlap boost, used when BM25 is unusable.

        Each hit's distance is reduced by ``0.1`` per query keyword found in
        its content (floored at 0) and the hits are re-sorted by distance.
        Keywords are whitespace-split query words longer than two characters
        that are not in the stop-word set.

        Returns:
            The boosted, sorted hits; an empty list on any error.
        """
        try:
            query_embedding = embedding_service.create_embedding(query)
            raw_hits = chroma_client.get_collection().query(
                query_embeddings=[query_embedding],
                n_results=n_results,
                where=where_clause
            )
            docs = self._format_results(raw_hits)

            stop_words = {'в', 'на', 'и', 'а', 'но', 'по', 'с', 'у', 'к', 'о', 'это', 'как', 'его', 'что', 'для'}
            terms = {
                word for word in query.lower().split()
                if len(word) > 2 and word not in stop_words
            }
            for doc in docs:
                haystack = doc['content'].lower()
                match_count = sum(term in haystack for term in terms)
                doc['keyword_boost'] = match_count * 0.1
                doc['distance'] = max(0, doc['distance'] - doc['keyword_boost'])
            docs.sort(key=lambda d: d['distance'])
            return docs
        except Exception as e:
            logger.error(f"Ошибка fallback keyword search: {e}")
            return []
    def _apply_similarity_threshold(self, documents, threshold):

        if not documents:
            return []
        similarities = []
        for doc in documents:
            distance = doc.get('distance', 1.0)
            similarity = 1.0 - distance
            similarities.append((doc, similarity))
        sorted_similarities = sorted(similarities, key=lambda x: x[1], reverse=True)
        max_sim = sorted_similarities[0][1] if sorted_similarities else 0
        min_sim = sorted_similarities[-1][1] if sorted_similarities else 0
        max_dist = sorted_similarities[0][0].get('distance', 1.0) if sorted_similarities else 1.0
        min_dist = sorted_similarities[-1][0].get('distance', 1.0) if sorted_similarities else 1.0
        filtered = [doc for doc, sim in similarities if sim >= threshold]
        logger.info(f"Threshold {threshold}: max_sim={max_sim:.4f} (dist={max_dist:.4f}), min_sim={min_sim:.4f} (dist={min_dist:.4f}), прошло {len(filtered)} из {len(documents)}")
        if not filtered and documents:
            if threshold > max_sim:
                logger.info(f"Threshold {threshold} выше максимальной similarity {max_sim:.4f} (distance={max_dist:.4f}). Возвращаем пустой список (threshold работает корректно)")
                return []
            else:
                logger.info(f"Threshold {threshold} отфильтровал все документы (max_similarity={max_sim:.4f}, distance={max_dist:.4f}). Возвращаем пустой список")
                return []
        filtered.sort(key=lambda x: x.get('distance', 1.0))
        return filtered
    def _build_where_clause(self):

        where = {}
        if self.settings.get('filter_by_category') and self.settings.get('category'):
            where['category'] = self.settings['category']
        if self.settings.get('filter_by_date'):
            if self.settings.get('date_from'):
                where.setdefault('date', {})['$gte'] = self.settings['date_from']
            if self.settings.get('date_to'):
                where.setdefault('date', {})['$lte'] = self.settings['date_to']
            if 'date' not in where:
                logger.debug("filter_by_date=True, но даты не указаны - фильтр не применяется")
        if self.settings.get('filter_by_size'):
            if self.settings.get('min_size'):
                where.setdefault('size', {})['$gte'] = self.settings['min_size']
            if self.settings.get('max_size'):
                where.setdefault('size', {})['$lte'] = self.settings['max_size']
            if 'size' not in where:
                logger.debug("filter_by_size=True, но размеры не указаны - фильтр не применяется")
        return where if where else None
    def _simple_rerank(self, documents, query):

        query_lower = query.lower()
        query_keywords = set(query_lower.split())
        stop_words = {'в', 'на', 'и', 'а', 'но', 'по', 'с', 'у', 'к', 'о', 'это', 'как', 'его', 'что', 'для', 'когда', 'где', 'кто', 'какой'}
        query_keywords = {w for w in query_keywords if len(w) > 2 and w not in stop_words}
        query_phrase = query_lower.strip()
        reranked = []
        for doc in documents:
            content = doc.get('content', '').lower()
            distance = doc.get('distance', 1.0)
            phrase_match = query_phrase in content
            keyword_matches = sum(1 for kw in query_keywords if kw in content)
            if phrase_match:
                boost = 0.5
            elif keyword_matches > 0:
                boost = 1.0 - (keyword_matches * 0.1)
                boost = max(boost, 0.7)
            else:
                boost = 1.0
            boosted_distance = distance * boost
            reranked.append({
                **doc,
                'distance': boosted_distance,
                'boosted_distance': boosted_distance,
                'keyword_matches': keyword_matches,
                'phrase_match': phrase_match
            })
        reranked.sort(key=lambda x: x.get('distance', 1.0))
        return reranked
    def _onnx_rerank(self, documents, query):
        """Rerank the first ``RERANKER_MAX_DOCS`` documents with the ONNX model.

        Query and document are passed to the tokenizer as a text pair so the
        tokenizer inserts its own special tokens between the two segments,
        rather than gluing them with a single separator string.

        Each scored document gets ``rerank_score`` (the raw model logit) and
        ``distance = 1 - sigmoid(logit)``, so a lower logit always yields a
        larger distance. Documents beyond the limit are appended unscored in
        their original order.

        Environment:
            RERANKER_MAX_DOCS: number of documents to score (default 30).
            RERANKER_BATCH_SIZE: tokenizer/model batch size (default 64,
                values below 1 are treated as 1).

        Returns:
            Scored documents sorted by descending ``rerank_score`` followed by
            the unscored tail. On any error the result of
            ``_crossencoder_rerank_fallback``.
        """
        start_time = time.time()
        try:
            import math

            def sigmoid(value):
                # Numerically stable for large negative logits.
                if value >= 0:
                    return 1.0 / (1.0 + math.exp(-value))
                exp_value = math.exp(value)
                return exp_value / (1.0 + exp_value)

            max_docs = int(os.getenv('RERANKER_MAX_DOCS', '30'))
            # range() below needs a positive step.
            batch_size = max(1, int(os.getenv('RERANKER_BATCH_SIZE', '64')))
            docs_to_rerank = documents[:max_docs]
            logger.debug(f"ONNX реранкинг: {len(docs_to_rerank)} документов, batch_size={batch_size}")
            scores = []
            for i in range(0, len(docs_to_rerank), batch_size):
                batch_docs = docs_to_rerank[i:i + batch_size]
                batch_texts = [doc.get('content', '') for doc in batch_docs]
                batch_inputs = self._onnx_tokenizer(
                    [query] * len(batch_texts),
                    batch_texts,
                    return_tensors='np',
                    truncation=True,
                    max_length=512,
                    padding=True,
                    return_attention_mask=True
                )
                batch_outputs = self._onnx_model(**batch_inputs)
                batch_scores = batch_outputs.logits.squeeze(-1).tolist()
                scores.extend(batch_scores)
            reranked = []
            for i, doc in enumerate(docs_to_rerank):
                score = float(scores[i])
                doc_id = f"{doc.get('metadata', {}).get('source', 'unknown')}_chunk_{doc.get('metadata', {}).get('chunk_index', i)}"
                logger.debug(f"ONNX реранкинг: {doc_id} -> score={score:.4f}")
                reranked.append({
                    **doc,
                    'rerank_score': score,
                    # Monotone in the score: abs() would give a strongly
                    # negative logit the same distance as a strongly positive one.
                    'distance': 1.0 - sigmoid(score)
                })
            reranked.sort(key=lambda x: x['rerank_score'], reverse=True)
            top5_info = []
            for r in reranked[:5]:
                source = r.get('metadata', {}).get('source', 'unknown')
                chunk_idx = r.get('metadata', {}).get('chunk_index', '?')
                score = r.get('rerank_score', 0)
                top5_info.append(f"{source}_chunk_{chunk_idx}({score:.4f})")
            logger.info(f"Топ-5 после ONNX реранкинга: {top5_info}")
            if len(documents) > max_docs:
                reranked.extend(documents[max_docs:])
            elapsed_time = time.time() - start_time
            logger.info(f"Использован ONNX реранкинг ({len(docs_to_rerank)} документов оценено за {elapsed_time:.3f} сек)")
            return reranked
        except Exception as e:
            logger.error(f"Ошибка ONNX reranking: {e}, используем обычный CrossEncoder")
            import traceback
            logger.error(traceback.format_exc())
            return self._crossencoder_rerank_fallback(documents, query)
    def _crossencoder_rerank_fallback(self, documents, query):
        """Rerank with a sentence-transformers CrossEncoder.

        The model (``RERANKER_MODEL``, default ``BAAI/bge-reranker-base``) is
        created on first use and kept on the instance. Only the first
        ``RERANKER_MAX_DOCS`` documents are scored; each gets ``rerank_score``
        and ``distance = 1 / (1 + score)``. The rest are appended unscored.

        Returns:
            Scored documents sorted by descending ``rerank_score`` followed by
            the unscored tail. The result of ``_simple_rerank`` when
            CrossEncoder is unavailable or any error occurs.
        """
        start_time = time.time()
        if not CROSSENCODER_AVAILABLE:
            return self._simple_rerank(documents, query)
        try:
            if self._crossencoder_model is None:
                model_name = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-base')
                logger.info(f"Загрузка CrossEncoder модели (fallback): {model_name}")
                self._crossencoder_model = CrossEncoder(model_name)
                logger.info("CrossEncoder модель загружена")

            max_docs = int(os.getenv('RERANKER_MAX_DOCS', '30'))
            batch_size = int(os.getenv('RERANKER_BATCH_SIZE', '64'))
            head = documents[:max_docs]
            query_doc_pairs = [(query, item.get('content', '')) for item in head]
            scores = self._crossencoder_model.predict(
                query_doc_pairs, batch_size=batch_size, show_progress_bar=False
            )

            scored = []
            for position, item in enumerate(head):
                relevance = float(scores[position])
                scored.append({
                    **item,
                    'rerank_score': relevance,
                    'distance': 1.0 / (1.0 + relevance)
                })
            reranked = sorted(scored, key=lambda entry: entry['rerank_score'], reverse=True)

            # Documents beyond the scoring limit keep their original order.
            if len(documents) > max_docs:
                reranked.extend(documents[max_docs:])
            elapsed_time = time.time() - start_time
            logger.info(f"Использован CrossEncoder реранкинг (fallback) ({len(head)} документов оценено за {elapsed_time:.3f} сек)")
            return reranked
        except Exception as e:
            logger.error(f"Ошибка CrossEncoder fallback: {e}")
            return self._simple_rerank(documents, query)
    def _crossencoder_rerank(self, documents, query):

        if ONNX_AVAILABLE and not self._onnx_tried and self._onnx_model is None:
            self._onnx_tried = True
            try:
                model_name = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-base')
                logger.info(f"Попытка загрузки ONNX реранкера: {model_name}")
                onnx_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'onnx_models', model_name.replace('/', '_'))
                os.makedirs(onnx_cache_dir, exist_ok=True)
                try:
                    logger.info(f"Пробуем загрузить ONNX модель из кэша: {onnx_cache_dir}")
                    import onnxruntime as ort
                    sess_options = ort.SessionOptions()
                    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
                    sess_options.intra_op_num_threads = 0
                    sess_options.inter_op_num_threads = 0
                    from optimum.onnxruntime import ORTModelForSequenceClassification
                    self._onnx_model = ORTModelForSequenceClassification.from_pretrained(
                        onnx_cache_dir,
                        export=False,
                        provider='CPUExecutionProvider',
                        session_options=sess_options
                    )
                    self._onnx_tokenizer = AutoTokenizer.from_pretrained(onnx_cache_dir)
                    logger.info("ONNX модель загружена из кэша с оптимизированными настройками (быстро)")
                except Exception as cache_error:
                    logger.warning(f"Не удалось загрузить из кэша: {cache_error}, конвертируем заново...")
                    import shutil
                    hf_cache_path = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub', f'models--{model_name.replace("/", "--")}')
                    if os.path.exists(hf_cache_path):
                        try:
                            for root, dirs, files in os.walk(hf_cache_path):
                                if 'onnx' in dirs:
                                    onnx_dir = os.path.join(root, 'onnx')
                                    logger.info(f"Удаляем старую ONNX версию из кэша HuggingFace: {onnx_dir}")
                                    shutil.rmtree(onnx_dir, ignore_errors=True)
                        except Exception as e:
                            logger.warning(f"Не удалось очистить кэш HuggingFace: {e}")
                    import onnxruntime as ort
                    sess_options = ort.SessionOptions()
                    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
                    sess_options.intra_op_num_threads = 0
                    sess_options.inter_op_num_threads = 0
                    self._onnx_model = ORTModelForSequenceClassification.from_pretrained(
                        model_name,
                        export=True,
                        provider='CPUExecutionProvider',
                        session_options=sess_options
                    )
                    self._onnx_tokenizer = AutoTokenizer.from_pretrained(model_name)
                    logger.info(f"Сохраняем ONNX модель в кэш: {onnx_cache_dir}")
                    self._onnx_model.save_pretrained(onnx_cache_dir)
                    self._onnx_tokenizer.save_pretrained(onnx_cache_dir)
                    logger.info("ONNX модель успешно конвертирована и сохранена в кэш с оптимизацией")
                logger.info("ONNX реранкер успешно загружен (быстрее чем CrossEncoder на CPU)")
            except Exception as e:
                logger.warning(f"Не удалось загрузить ONNX реранкер: {e}, используем обычный CrossEncoder (больше не будем пытаться)")
                import traceback
                logger.error(traceback.format_exc())
                self._onnx_model = None
                self._onnx_tokenizer = None
        if ONNX_AVAILABLE and self._onnx_model is not None and self._onnx_tokenizer is not None:
            return self._onnx_rerank(documents, query)
        if not CROSSENCODER_AVAILABLE:
            logger.warning("CrossEncoder не доступен, используем simple reranking")
            return self._simple_rerank(documents, query)
        start_time = time.time()
        try:
            if self._crossencoder_model is None:
                model_name = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-base')
                logger.info(f"Загрузка CrossEncoder модели: {model_name}")
                self._crossencoder_model = CrossEncoder(model_name)
                logger.info("CrossEncoder модель загружена")
            max_docs = int(os.getenv('RERANKER_MAX_DOCS', '30'))
            batch_size = int(os.getenv('RERANKER_BATCH_SIZE', '64'))
            docs_to_rerank = documents[:max_docs]
            logger.debug(f"Реранкинг: {len(docs_to_rerank)} документов, batch_size={batch_size}")
            pairs = [(query, doc.get('content', '')) for doc in docs_to_rerank]
            scores = self._crossencoder_model.predict(pairs, batch_size=batch_size, show_progress_bar=False)
            reranked = []
            for i, doc in enumerate(docs_to_rerank):
                score = float(scores[i])
                doc_id = f"{doc.get('metadata', {}).get('source', 'unknown')}_chunk_{doc.get('metadata', {}).get('chunk_index', i)}"
                logger.debug(f"CrossEncoder реранкинг: {doc_id} -> score={score:.4f}")
                reranked.append({
                    **doc,
                    'rerank_score': score,
                    'distance': 1.0 / (1.0 + score)
                })
            reranked.sort(key=lambda x: x['rerank_score'], reverse=True)
            top5_info = []
            for r in reranked[:5]:
                source = r.get('metadata', {}).get('source', 'unknown')
                chunk_idx = r.get('metadata', {}).get('chunk_index', '?')
                score = r.get('rerank_score', 0)
                top5_info.append(f"{source}_chunk_{chunk_idx}({score:.4f})")
            logger.info(f"Топ-5 после CrossEncoder реранкинга: {top5_info}")
            if len(documents) > max_docs:
                reranked.extend(documents[max_docs:])
            elapsed_time = time.time() - start_time
            logger.info(f"Использован CrossEncoder реранкинг ({len(docs_to_rerank)} документов оценено за {elapsed_time:.3f} сек)")
            return reranked
        except Exception as e:
            logger.error(f"Ошибка CrossEncoder reranking: {e}, используем simple reranking")
            import traceback
            logger.error(traceback.format_exc())
            return self._simple_rerank(documents, query)
    def _advanced_rerank(self, documents, query):
        """Rerank documents, preferring the remote reranker service.

        The first ``RERANKER_MAX_DOCS`` documents (environment variable,
        default 100) are sent to ``reranker_client`` when that client could
        be imported. If the service is unavailable, raises, or returns
        ``None``, the full ``documents`` list is handed to the local
        ``_crossencoder_rerank`` instead.

        Args:
            documents: Candidate documents in their current order.
            query: Query the documents are scored against.

        Returns:
            The service result when it produced one (only the truncated
            candidates were sent to it), otherwise whatever
            ``_crossencoder_rerank`` returns for the full list.
        """
        service_limit = int(os.getenv('RERANKER_MAX_DOCS', '100'))
        candidates = documents[:service_limit]
        service_result = None
        if RERANKER_SERVICE_AVAILABLE:
            try:
                logger.debug(f"Попытка использовать сервис реранкинга для балансировки нагрузки ({len(candidates)} документов)")
                service_result = reranker_client.rerank(candidates, query)
                if service_result is not None:
                    logger.info(f"Использован сервис реранкинга: {len(service_result)} документов")
            except Exception as e:
                logger.warning(f"Ошибка сервиса реранкинга: {e}, используем локальный реранкинг")
                # Any failure after the call (including while logging) means local fallback.
                service_result = None
        if service_result is not None:
            return service_result
        return self._crossencoder_rerank(documents, query)
    def _format_results(self, results):

        documents = []
        if results.get('documents') and results['documents'][0]:
            for i, doc in enumerate(results['documents'][0]):
                documents.append({
                    'content': doc,
                    'metadata': results['metadatas'][0][i] if results.get('metadatas') and results['metadatas'][0] else {},
                    'distance': results['distances'][0][i] if results.get('distances') and results['distances'][0] else 0.0
                })
        return documents
    def _merge_results(self, list1, list2):

        seen = set()
        merged = []
        for doc in list1 + list2:
            content_hash = hash(doc['content'])
            if content_hash not in seen:
                seen.add(content_hash)
                merged.append(doc)
        merged.sort(key=lambda x: x['distance'])
        return merged
    def _reciprocal_rank_fusion(self, list1: List[Dict], list2: List[Dict], k: int = 60) -> List[Dict]:
        """Fuse two ranked result lists with Reciprocal Rank Fusion (RRF).

        Every occurrence of a document at 1-based position ``rank`` in either
        list adds ``1 / (k + rank)`` to that document's fused score.

        Documents are identified by ``"<source>_chunk_<chunk_index>"`` built
        from their metadata. A document whose metadata has neither key is
        identified by its ``'content'`` instead, so unrelated documents
        without metadata are not merged into one entry. (Documents lacking
        both metadata and content still share a single key.)

        Args:
            list1: First ranked list of document dicts, best first.
            list2: Second ranked list of document dicts, best first.
            k: RRF smoothing constant added to each rank.

        Returns:
            A new list of shallow copies of the first-seen dict for each
            document, sorted by descending ``'rrf_score'``. Each copy gets
            ``'rrf_score'`` and ``'distance'`` set to
            ``1 / (1 + rrf_score * 100)``.
        """
        def get_doc_key(doc):
            # 'metadata' may be present but None; treat that as empty.
            metadata = doc.get('metadata') or {}
            if 'source' not in metadata and 'chunk_index' not in metadata:
                # Without identifying metadata every document would map to
                # "unknown_chunk_0"; fall back to the content.
                return ('content', doc.get('content'))
            source = metadata.get('source', 'unknown')
            chunk_index = metadata.get('chunk_index', 0)
            return ('chunk', f"{source}_chunk_{chunk_index}")

        fused = {}
        for ranking in (list1, list2):
            for rank, doc in enumerate(ranking, start=1):
                # The first dict seen for a key is the one that is returned.
                entry = fused.setdefault(get_doc_key(doc), {'doc': doc, 'score': 0.0})
                entry['score'] += 1.0 / (k + rank)

        merged = []
        for entry in fused.values():
            doc = entry['doc'].copy()
            doc['rrf_score'] = entry['score']
            # Map the fused score onto a distance-like value (higher score -> smaller distance).
            doc['distance'] = 1.0 / (1.0 + entry['score'] * 100)
            merged.append(doc)
        logger.debug(f"RRF дедупликация: {len(list1)} + {len(list2)} -> {len(merged)} уникальных документов")
        merged.sort(key=lambda doc: doc['rrf_score'], reverse=True)
        return merged
    def _cosine_similarity(self, vec1, vec2):

        try:
            vec1 = np.array(vec1)
            vec2 = np.array(vec2)
            dot_product = np.dot(vec1, vec2)
            norm1 = np.linalg.norm(vec1)
            norm2 = np.linalg.norm(vec2)
            if norm1 == 0 or norm2 == 0:
                return 0.0
            return dot_product / (norm1 * norm2)
        except Exception as e:
            logger.warning(f"Ошибка вычисления cosine similarity: {e}")
            return 0.0
# Module-level ProperRetriever instance, constructed once when this module is
# imported (presumably shared as a singleton by importers; verify against callers).
proper_retriever = ProperRetriever()