# Source code for rupsycho.mixins.experiment_processing

# ===========================================================================
#                      ExperimentProcessingMixin Implementation
# ===========================================================================
#  This mixin provides methods for processing experiments by creating and
#  executing chains with different models, profiles, and instructions. It
#  includes error handling, progress tracking, and result storage.


import gc
import torch
import warnings
from collections import defaultdict
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_huggingface.llms import HuggingFacePipeline
from rupsycho.utils import import_tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from typing import Any, Dict, Optional
tqdm = import_tqdm()  # Import tqdm based on the environment

# ------------------- DEFAULT MODEL -------------------


def get_default_model(
    model_id: str = "google/flan-t5-small",
    **generation_kwargs: Any,
) -> HuggingFacePipeline:
    """Build the default LangChain LLM backed by a local Hugging Face pipeline.

    Args:
        model_id: Hugging Face model identifier to load. It is loaded with
            ``AutoModelForSeq2SeqLM``, so it must be a seq2seq model.
            Defaults to ``"google/flan-t5-small"``.
        **generation_kwargs: Overrides merged on top of the default generation
            parameters (``min_new_tokens=1``, ``max_new_tokens=64``,
            ``temperature=0.6``, ``do_sample=True``). The merged dict is
            forwarded to ``transformers.pipeline``.

    Returns:
        A ``HuggingFacePipeline`` wrapping a ``text2text-generation`` pipeline.
    """
    # Defaults first so caller-supplied values win on key collisions.
    params = {
        "min_new_tokens": 1,
        "max_new_tokens": 64,
        "temperature": 0.6,
        "do_sample": True,
        **generation_kwargs,
    }
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    pipe = pipeline(
        "text2text-generation", model=model, tokenizer=tokenizer, **params
    )
    return HuggingFacePipeline(pipeline=pipe)
# ------------------------------------------------
class ExperimentProcessingMixin:
    """
    A mixin that provides methods for processing experiments with various
    models, profiles, and instructions.

    This class is designed to facilitate the creation and execution of
    experiment chains, manage errors, track progress, and store results.

    The host class must provide the attributes these methods read:
    ``questionnaire``, ``demographic_profiles``, ``runnable_models``,
    ``runnable_prompt``, ``runnable_parser``, ``parameters``, ``models`` and
    ``name``.
    """

    def _create_input_dict(self, profile: Any, instruction_item: Any) -> Dict[str, Any]:
        """Create the input dictionary required for invoking the chain.

        Args:
            profile: Demographic profile; must provide ``get_profile_desc()``.
            instruction_item: Questionnaire item; must provide ``question`` and
                ``answer_options``. When ``answer_options`` is falsy, the
                questionnaire's ``default_answer_options`` are used instead.

        Returns:
            Dict with keys ``general_instruction``, ``persona_description``,
            ``question`` and ``answer_options``.
        """
        # Prefer the item's own answer options; fall back to the questionnaire default.
        if instruction_item.answer_options:
            answer_options = instruction_item.answer_options.join_options()
        else:
            answer_options = self.questionnaire.default_answer_options.join_options()

        return {
            "general_instruction": self.questionnaire.general_instruction,
            "persona_description": profile.get_profile_desc(),
            "question": instruction_item.question,
            # Single string produced by the options object's join_options().
            "answer_options": answer_options,
        }

    def _get_chain(
        self,
        prompt_template: Runnable,
        model: Runnable,
        parser: Runnable,
        name: str,
        seed: int = 42,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Build the chain ``passthrough | prompt | model | parser`` for one run.

        Args:
            prompt_template: Prompt runnable; must be set.
            model: Model runnable; must be set. It is bound with ``seed``.
            parser: Output parser; a ``StrOutputParser`` is used when falsy.
            name: Run name attached to the chain via ``with_config``.
            seed: Random seed bound to the model (converted to ``int``).
            params: Extra model parameters. NOTE: currently not applied to the
                chain (see TODO below); accepted for forward compatibility.

        Returns:
            The configured chain.

        Raises:
            ValueError: If ``prompt_template`` or ``model`` is not set.
        """
        if not prompt_template:
            raise ValueError("Prompt template not set.")
        if not model:
            raise ValueError("Model not set.")
        if not parser:
            parser = StrOutputParser()

        return (
            RunnablePassthrough()
            | prompt_template
            | model.bind(seed=int(seed))  # TODO: Add params
            | parser
        ).with_config(run_name=name)

    def _generate_answer(self, chain, input_values):
        """Invoke ``chain`` on ``input_values`` and return its output.

        Failures are best-effort: the error is reported as a warning (same
        mechanism as ``_trigger_callbacks``) and ``None`` is returned so the
        experiment can continue with the next prompt.
        """
        try:
            return chain.invoke(input_values)
        except Exception as e:
            warnings.warn(f"Error invoking chain for run: {e}")
            return None

    def _ensure_requirements_to_run(self) -> bool:
        """Validate that models, prompt, questionnaire and profiles are set.

        Returns:
            ``True`` when all requirements are met.

        Raises:
            ValueError: If ``runnable_models`` is empty, or if
                ``runnable_prompt``, ``questionnaire`` or
                ``demographic_profiles`` is ``None``.
        """
        if not self.runnable_models:
            raise ValueError("No models have been set in runnable_models.")
        if self.runnable_prompt is None:
            raise ValueError("runnable_prompt has not been set.")
        if self.questionnaire is None:
            raise ValueError("questionnaire has not been set.")
        if self.demographic_profiles is None:
            raise ValueError("demographic_profiles has not been set.")
        return True

    def _get_seed_values(self) -> list:
        """Return the configured seeds, or ``[42]`` when none are configured."""
        return self.parameters.seeds if self.parameters.seeds else [42]

    def _calculate_total_iterations(self) -> int:
        """Return the number of prompts the run will issue (progress-bar total).

        This is items x models x profiles x seeds, matching the nested loops
        in ``process_single_experiment`` / ``_generate_and_process_answers``.
        """
        return (
            len(self.questionnaire.instruction_items)
            * len(self.runnable_models)
            * len(self.demographic_profiles)
            * len(self._get_seed_values())
        )

    def _is_runnable(self, model):
        """Return whether ``model`` is a LangChain ``Runnable``."""
        return isinstance(model, Runnable)

    def _cleanup_memory(self):
        """Run garbage collection and release cached CUDA memory."""
        gc.collect()
        torch.cuda.empty_cache()

    def _load_model(self, model):
        """Return a runnable model, lazily loading it if necessary.

        Non-``Runnable`` entries are loaded via their ``load_model()`` method
        (presumably a lazy model spec — confirm against the model classes).
        """
        if not self._is_runnable(model):
            return model.load_model()
        return model

    def _generate_and_process_answers(self, model, model_id, seed_values, params,
                                      questionnaire, demographic_profiles, callbacks, pbar):
        """Generate and record answers for one model.

        Loops over seeds (one chain per seed), then instruction items, then
        demographic profiles. Each answer is stored on the instruction item,
        passed to every callback, and advances ``pbar`` by one.
        """
        for random_seed in seed_values:
            # One chain per (model, seed) combination.
            chain = self._get_chain(self.runnable_prompt, model, self.runnable_parser,
                                    model_id, seed=int(random_seed), params=params)
            for instruction_item_id, instruction_item in enumerate(questionnaire.instruction_items):
                for profile_id, profile in demographic_profiles.items():
                    input_values = self._create_input_dict(profile, instruction_item)
                    answer = self._generate_answer(chain, input_values)
                    instruction_item.update_answer(model_id, profile_id, random_seed, answer)
                    self._trigger_callbacks(callbacks, instruction_item_id, instruction_item,
                                            model_id, profile_id, random_seed, answer)
                    pbar.update(1)

    def _trigger_callbacks(self, callbacks, instruction_item_id, instruction_item,
                           model_id, profile_id, random_seed, answer):
        """Call ``save_answer`` on every callback.

        A failing callback is reported as a warning and does not stop the
        remaining callbacks or the experiment.
        """
        for callback in callbacks:
            try:
                callback.save_answer(self, instruction_item_id, instruction_item,
                                     model_id, profile_id, random_seed, answer)
            except Exception as e:
                warnings.warn(f"Error while saving answer: {e}")

    def process_single_experiment(self, pbar=None, callbacks=None) -> None:
        """Process the experiment, iterating through models, profiles, and instruction items.

        Args:
            pbar: Optional progress bar with an ``update(n)`` method. If
                omitted, a ``tqdm`` bar is created here and closed on exit.
            callbacks: Optional iterable of objects with a ``save_answer``
                method; ``None`` means no callbacks.

        Raises:
            ValueError: If the requirements checked by
                ``_ensure_requirements_to_run`` are not met.

        NOTE: after a model is processed its entry in ``runnable_models`` is
        replaced by ``None`` to free memory, so the same instance cannot be
        run a second time without re-populating ``runnable_models``.
        """
        self._ensure_requirements_to_run()
        callbacks = [] if callbacks is None else callbacks

        # Only close the bar we created ourselves; a caller-supplied bar is theirs.
        owns_pbar = pbar is None
        if owns_pbar:
            pbar = tqdm(total=self._calculate_total_iterations(),
                        desc=self.name or "Experiment", unit=" prompts")

        try:
            # Parameters shared by all models; seeds are handled separately.
            default_params = self.parameters.model_dump(exclude_none=True, exclude={'seeds'})

            seed_values = self._get_seed_values()
            demographic_profiles = self.demographic_profiles
            questionnaire = self.questionnaire

            # Assigning to existing keys while iterating is safe (dict size is unchanged).
            for model_id, model in self.runnable_models.items():
                model = self._load_model(model)
                self.runnable_models[model_id] = model

                # Model-specific parameters override the defaults; tolerate a missing config
                # or parameters set to None.
                model_config = self.models.get(model_id)
                model_params = (model_config.parameters if model_config else None) or {}
                params = {**default_params, **model_params}

                self._generate_and_process_answers(
                    model, model_id, seed_values, params, questionnaire,
                    demographic_profiles, callbacks, pbar)

                # Drop every reference to the loaded model before collecting,
                # so CPU/GPU memory can actually be reclaimed.
                del model
                self.runnable_models[model_id] = None
                self._cleanup_memory()
        finally:
            if owns_pbar:
                pbar.close()

    def run(self, callbacks=None) -> None:
        """Run the whole experiment under a ``tqdm`` progress bar.

        Args:
            callbacks: Optional iterable of objects with a ``save_answer``
                method; ``None`` means no callbacks.

        Raises:
            ValueError: If the requirements checked by
                ``_ensure_requirements_to_run`` are not met.
        """
        # Validate before sizing the progress bar, which reads the same attributes.
        self._ensure_requirements_to_run()
        with tqdm(total=self._calculate_total_iterations(),
                  desc=self.name or "Experiment", unit=" prompts") as pbar:
            self.process_single_experiment(pbar, callbacks=callbacks)