# Neural Correlates of Memory and Attention Study 2025
# inspired by: Shatek, S. M., Grootswagers, T., Robinson, A. K., & Carlson, T. A. (2019). Decoding Images in the Mind’s Eye: The Temporal Dynamics of Visual Imagery. Vision, 3(4), 53. https://doi.org/10.3390/vision3040053

## IMPORT MODULES
import os
import random
import csv
import datetime
from collections import Counter

import numpy as np
from psychopy import visual, core, event
import pandas as pd

# -------------------- Trigger I/O --------------------
# Trigger backend. Exactly one of:
#   'parallel' -> LPT port (true LPT, or e.g. a BrainProducts TriggerBox in LPT mode)
#   'serial'   -> USB-serial trigger box
#   'none'     -> no hardware; every trigger is a no-op
TRIGGER_BACKEND = 'none'

# Width of each TTL pulse, in milliseconds.
TRIGGER_PULSE_MS = 10

# Code sent when the black screen begins.
# This might need to be adjusted depending on the trigger box used.
TRIGGER_CODE_BLACK_ONSET = 11

# Parallel-port base address. Common values: 0x0378 (LPT1), 0x03BC, 0x0278.
PARALLEL_ADDRESS = 0x378

# Serial-port settings (USB-serial trigger boxes).
SERIAL_PORT = 'COM3'
SERIAL_BAUD = 115_200

class TriggerSender:
    """
    Send single-byte trigger codes over a parallel port, a serial port, or nowhere.

    Backends:
        'parallel' -> ParallelPort at PARALLEL_ADDRESS
        'serial'   -> pyserial port SERIAL_PORT at SERIAL_BAUD
        'none'     -> no-op

    If the requested backend cannot be opened, or the backend name is not
    recognised, a warning is printed and the sender degrades to the no-op
    'none' backend, so the experiment still runs (just without triggers).
    """

    def __init__(self, backend='none'):
        self.backend = backend
        self.pulse_s = TRIGGER_PULSE_MS / 1000.0
        self.ready = False
        if backend == 'parallel':
            try:
                self.port = self._open_parallel_port()
                self.port.setData(0)  # start from an all-low line state
                self.ready = True
            except Exception as e:
                print(f"⚠️ Parallel trigger init failed: {e}. Falling back to 'none'.")
                self.backend = 'none'
        elif backend == 'serial':
            try:
                import serial  # pyserial; imported lazily so it is only required for this backend
                self.ser = serial.Serial(SERIAL_PORT, SERIAL_BAUD, timeout=0)
                self.ready = True
            except Exception as e:
                print(f"⚠️ Serial trigger init failed: {e}. Falling back to 'none'.")
                self.backend = 'none'
        else:
            if backend != 'none':
                # A typo here would otherwise silently disable all triggers.
                print(f"⚠️ Unknown trigger backend {backend!r}. Falling back to 'none'.")
                self.backend = 'none'
            self.ready = True  # 'none' backend is a no-op

    @staticmethod
    def _open_parallel_port():
        """
        Open the parallel port at PARALLEL_ADDRESS.

        Tries `psychopy.hardware.parallel` first (the location this script
        originally used), then falls back to the documented `psychopy.parallel`
        module.
        """
        try:
            from psychopy.hardware import parallel
        except ImportError:
            from psychopy import parallel
        return parallel.ParallelPort(address=PARALLEL_ADDRESS)

    def send_now(self, code: int):
        """
        Send `code` (masked to 8 bits) as one pulse and return once it is cleared.

        The code is written, held for TRIGGER_PULSE_MS, then reset to 0. This
        blocks for the pulse duration. It does nothing when initialisation
        failed or the backend is 'none'.
        """
        if not self.ready:
            return
        code = int(code) & 0xFF
        if self.backend == 'parallel':
            self.port.setData(code)
            core.wait(self.pulse_s)
            self.port.setData(0)
        elif self.backend == 'serial':
            self.ser.write(bytes([code]))
            core.wait(self.pulse_s)
            self.ser.write(b'\x00')
        # 'none': no-op


# ----- Set working directory to script location -----
# abspath() first: when __file__ is a bare filename (relative paths are possible
# before Python 3.9), dirname() would be '' and os.chdir('') would raise.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# -------------------- Config --------------------
# Set SEED to an int for a reproducible session (seeds both `random` and NumPy);
# None leaves both generators unseeded.
SEED = None
if SEED is not None:
    random.seed(SEED)
    np.random.seed(SEED)

N_CATEGORIES = 19               # NOTE(review): appears unused; available categories come from the folders found on disk
IMAGES_PER_CATEGORY = 4         # at most this many images are used from each category folder
PICKED_CATEGORIES = 2           # categories drawn for the session's 4-image set
IMAGES_PER_PICKED_CATEGORY = 2  # images drawn per picked category; also the minimum a category folder needs to be usable
# 96 = 24 x 4: intended to cover every ordering of the 4 images combined with every cued position.
# NOTE(review): counterbalance_positions() only cycles through the 4 cyclic rotations of the
# image order, and cue numbers are shuffled independently of the order, so not every
# order/cue combination actually occurs. Confirm this matches the intended design.
TRIALS = 96
PRACTICE_TRIALS = 3

# durations all in s (seconds)
IMAGE_DURATION = 1.000     # per image in the sequence; informed by: Potter, M. C. (2012). Recognition and Memory for Briefly Presented Scenes. Frontiers in Psychology, 3, 20118. https://doi.org/10.3389/fpsyg.2012.00032
NOISE1_DURATION = 0.500    # dynamic noise between the image sequence and the cue
NOISE2_DURATION = 2.000    # inter-trial dynamic noise
BLACK_DURATION = 3.000     # black screen during which the cued image is imagined
FEEDBACK_DURATION = 1.0    # practice feedback display time

# Paths (relative to script folder)
IMAGE_ROOT = "images"                    # main-task images root
PRACTICE_IMAGE_ROOT = "images_practice"  # practice images root
# images from: A. Khosla, A. S. Raju, A. Torralba and A. Oliva, "Understanding and Predicting Image Memorability at a Large Scale," 2015 IEEE International Conference on Computer Vision (ICCV), Santiago, Chile, 2015, pp. 2390-2398, doi: 10.1109/ICCV.2015.275. keywords: {Visualization;Games;Correlation;Benchmark testing;Delay effects;Delays;Computer vision},

WIN_SIZE = [1280, 720]
FULLSCREEN = True
BG_COLOR = "black"
TEXT_COLOR = "white"
FPS = 60  # nominal refresh rate; only used to pace the dynamic-noise loop

# -------------------- Image Setup --------------------
def build_image_library(root, max_per_category=None, min_per_category=None):
    """
    Scan `root` for `category_*` sub-folders and collect their images and scores.

    Each image's memorability score is parsed from its filename, which must end
    in `_<float>.<ext>` (e.g. `img1_0.836912578.jpg`). Files whose score cannot
    be parsed are reported and skipped, and the next file is used instead.

    Args:
        root: folder containing the `category_*` sub-folders.
        max_per_category: keep at most this many valid images per category,
            in sorted path order (default: IMAGES_PER_CATEGORY).
        min_per_category: categories with fewer valid images than this are
            left out of the library (default: IMAGES_PER_PICKED_CATEGORY).

    Returns:
        dict: {category_name: [list_of_image_paths]}
        dict: {image_path: memorability_score}. This covers every image that
            was parsed, including images from categories left out of the library.
        Both dicts are empty if `root` is not a directory.
    """
    if max_per_category is None:
        max_per_category = IMAGES_PER_CATEGORY
    if min_per_category is None:
        min_per_category = IMAGES_PER_PICKED_CATEGORY

    lib = {}
    scores = {}
    if not os.path.isdir(root):
        return lib, scores

    image_exts = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
    for entry in sorted(os.listdir(root)):
        cat_dir = os.path.join(root, entry)
        if not os.path.isdir(cat_dir) or not entry.startswith("category_"):
            continue
        imgs = sorted(
            os.path.join(cat_dir, f) for f in os.listdir(cat_dir)
            if f.lower().endswith(image_exts)
        )
        valid_imgs = []
        for img_path in imgs:
            # Stop at the first `max_per_category` VALID images, so one
            # malformed filename does not cost the category an image.
            if len(valid_imgs) >= max_per_category:
                break
            filename = os.path.basename(img_path)
            underscore_part = filename.rsplit("_", 1)[-1]  # e.g., '0.836912578.jpg'
            score_str = os.path.splitext(underscore_part)[0]
            try:
                score = float(score_str)
            except ValueError as e:
                print(f"⚠️ Failed to parse score from filename: {filename} — {e}")
                continue
            scores[img_path] = score
            valid_imgs.append(img_path)
        if len(valid_imgs) >= min_per_category:
            lib[entry] = valid_imgs
    return lib, scores

def choose_four_images(image_lib, n_categories=None, per_category=None):  # choosing 4 images from 2 categories
    """
    Randomly pick `n_categories` categories and `per_category` images from each.

    Images are returned grouped by category: the images of the first picked
    category, then those of the next. The rest of the task (sequence, cue
    1..4, 8-tile grid) assumes exactly 4 images.

    Args:
        image_lib: {category_name: [image_paths]}, as built by build_image_library().
        n_categories: categories to draw (default: PICKED_CATEGORIES).
        per_category: images drawn per category (default: IMAGES_PER_PICKED_CATEGORY).

    Returns:
        list[str]: the 4 chosen image paths.

    Raises:
        ValueError: if the counts do not give exactly 4 images, the library
            has too few categories, or a picked category has too few images.
            An explicit raise is used instead of `assert` so the check
            survives `python -O`.
    """
    if n_categories is None:
        n_categories = PICKED_CATEGORIES
    if per_category is None:
        per_category = IMAGES_PER_PICKED_CATEGORY

    if n_categories * per_category != 4:
        raise ValueError(
            f"The task needs exactly 4 images, got {n_categories} categories x "
            f"{per_category} images = {n_categories * per_category}."
        )
    if len(image_lib) < n_categories:
        raise ValueError(
            f"Need at least {n_categories} usable categories, found {len(image_lib)}."
        )

    cats = random.sample(list(image_lib.keys()), n_categories)
    chosen = []
    for cat in cats:
        if len(image_lib[cat]) < per_category:
            raise ValueError(
                f"Category {cat!r} has {len(image_lib[cat])} images; need {per_category}."
            )
        chosen.extend(random.sample(image_lib[cat], per_category))
    return chosen

def counterbalance_positions(n_trials, n_items=4):
    """
    Build one presentation order per trial from the cyclic rotations of range(n_items).

    Trials are grouped into blocks of `n_items`. Each block contains every
    rotation once, in shuffled order. Rotations form a Latin square, so within
    a block every item appears in every serial position exactly once. If
    `n_trials` is not a multiple of `n_items`, the leftover trials get a random
    subset of distinct rotations, so exactly `n_trials` orders are returned.

    Args:
        n_trials: number of orders to produce (<= 0 gives an empty list).
        n_items: number of items per order.

    Returns:
        list[list[int]]: `n_trials` orders. Each is a fresh list, so mutating
        one order never affects another.
    """
    if n_trials <= 0:
        return []
    # For n_items=4 this is [[0,1,2,3], [1,2,3,0], [2,3,0,1], [3,0,1,2]].
    rotations = [[(start + k) % n_items for k in range(n_items)] for start in range(n_items)]
    n_blocks, remainder = divmod(n_trials, n_items)
    orders = []
    for _ in range(n_blocks):
        block = [list(rot) for rot in rotations]
        random.shuffle(block)
        orders.extend(block)
    if remainder:
        orders.extend(list(rot) for rot in random.sample(rotations, remainder))
    return orders

def balanced_numbers(n_trials, n_items=4):
    """
    Return `n_trials` cue numbers from 1..n_items in shuffled order, balanced.

    Each number appears `n_trials // n_items` times. If `n_trials` is not a
    multiple of `n_items`, the leftover slots are filled with distinct,
    randomly chosen numbers, so the result always has exactly `n_trials`
    entries (an empty list for n_trials <= 0).
    """
    if n_trials <= 0:
        return []
    n_full, remainder = divmod(n_trials, n_items)
    counts = list(range(1, n_items + 1)) * n_full
    if remainder:
        counts.extend(random.sample(range(1, n_items + 1), remainder))
    random.shuffle(counts)
    return counts

# -------------------- Visual Setup --------------------
def build_image_stims(win, image_paths):
    """Create one square ImageStim per path; each side is 60% of the window's shorter dimension."""
    side = 0.6 * min(win.size)
    stims = []
    for path in image_paths:
        stims.append(visual.ImageStim(win, image=path, size=(side, side)))
    return stims

# ====== Noise helpers using normalized floats in [-1, 1] (Option A) ======
# use of dynamic noise inspired by: Wilson, H., Chen, X., Golbabaee, M., Proulx, M. J., & O’Neill, E. (2023). Feasibility of decoding visual information from EEG. Brain-Computer Interfaces, 11(1–2), 33–60. https://doi.org/10.1080/2326263X.2023.2287719
def _noise_frame(h=512, w=512, c=3):
    # Random in [0,1) -> scale to [-1,1], dtype float32
    return (np.random.rand(h, w, c).astype(np.float32) * 2.0) - 1.0

def make_noise_stim(win):
    """Create a full-window ImageStim initialised with one frame of noise (no interpolation)."""
    return visual.ImageStim(win, image=_noise_frame(), size=win.size, interpolate=False)

def show_dynamic_noise(win, duration_s, clock, fps=60):
    """
    Animate full-window dynamic noise for about `duration_s` seconds.

    Elapsed time is read from `clock`. Before every flip a fresh noise frame
    is generated. After each flip the loop also waits half of a nominal frame
    period, (1 / fps) / 2.
    """
    stim = make_noise_stim(win)
    t_begin = clock.getTime()
    half_frame = (1.0 / fps) / 2.0
    while clock.getTime() - t_begin < duration_s:
        stim.image = _noise_frame()
        stim.draw()
        win.flip()
        core.wait(half_frame)
# ========================================================================

def show_images_sequence(win, img_stims, order, duration_s, clock):
    """
    Present `img_stims` one at a time in the sequence given by `order`.

    Each image is drawn, flipped onto the screen, and held with core.wait(duration_s).
    `clock` is currently not read.
    """
    for position in order:
        stim = img_stims[position]
        stim.draw()
        win.flip()
        core.wait(duration_s)

def show_number_and_wait_space(win, number, text_stim):  # prompt an image and wait until the participant is imagining it
    """
    Show the cue `number` and block until Space or Escape is pressed.

    Returns the Space reaction time in seconds. The RT is measured by a clock
    started immediately after the flip that shows the cue. Escape ends the
    program via core.quit().
    """
    text_stim.text = str(number)
    text_stim.draw()
    win.flip()
    cue_clock = core.Clock()
    for key_name, key_time in event.waitKeys(keyList=['space', 'escape'], timeStamped=cue_clock):
        if key_name == 'escape':
            core.quit()
        elif key_name == 'space':
            return key_time

def show_black(win, duration_s, on_flip=None, clock=None):
    """
    Flip to a pure-black screen and hold it for `duration_s` seconds.

    A callback is registered with win.callOnFlip, so it runs right after the
    flip that shows the first black frame. It records, in this order:
        1) the wall-clock time (datetime.now()) -> "t_abs_dt"
        2) `clock.getTime()` if a clock is given -> "t_rel_s"
        3) then it calls `on_flip()` if given (used for the hardware trigger)

    Returns:
        dict: {"t_rel_s": float|None, "t_abs_dt": datetime|None}
    """
    onset = dict(t_rel_s=None, t_abs_dt=None)

    def _stamp_then_fire():
        onset["t_abs_dt"] = datetime.datetime.now()
        if clock is not None:
            onset["t_rel_s"] = clock.getTime()
        if on_flip is not None:
            on_flip()

    win.color = BG_COLOR  # keep the clear colour at the background colour
    win.callOnFlip(_stamp_then_fire)
    win.flip()
    core.wait(duration_s)
    return onset

def show_choice_8_images(win, img_paths):
    """
    Show the 8-tile recognition grid and wait for the participant to click a tile.

    Each of the 4 images in `img_paths` appears twice: once original and once
    horizontally mirrored. The 8 tiles are shuffled freshly on every call.
    Display slots 0-3 form the top row and slots 4-7 the bottom row. A
    selection requires a NEW left-button press that lands on a tile. After it,
    the function waits for the button to be released plus 0.1 s, then returns.

    Returns:
        tuple: (selected index in display order,
                the 8 ImageStims in display order (for feedback overlay),
                the header TextStim,
                display_items: list of {'id_idx': 0..3, 'flip': bool, 'path': str}
                in display order)
    """
    win_w = win.size[0]
    win_h = win.size[1]
    col_step = win_w * 0.2
    tile = min(win.size) * 0.25
    columns_x = [k * col_step for k in (-1.5, -0.5, 0.5, 1.5)]
    # Both rows sit +/- 25% of the window height from centre, nudged down by 5%.
    nudge = win_h * 0.05
    row_offset = win_h * 0.25
    rows_y = (row_offset - nudge, -row_offset - nudge)

    header = visual.TextStim(
        win,
        text="Which image did you imagine?",
        color=TEXT_COLOR,
        height=36,
        pos=(0, win_h * 0.40),
        wrapWidth=win_w * 0.9,
        alignText='center'
    )

    # 4 identities x {original, mirror}, then shuffled into display order.
    items = [
        {'id_idx': i, 'flip': mirrored, 'path': img_paths[i]}
        for i in range(4)
        for mirrored in (False, True)
    ]
    random.shuffle(items)

    stims = []
    display_items = []
    for slot, itm in enumerate(items):
        row, col = divmod(slot, 4)
        stims.append(visual.ImageStim(
            win, image=itm['path'], size=(tile, tile),
            pos=(columns_x[col], rows_y[row]), flipHoriz=itm['flip']
        ))
        display_items.append(itm)

    mouse = event.Mouse(visible=True, win=win)

    def _redraw():
        header.draw()
        for s in stims:
            s.draw()
        win.flip()

    # Swallow a button that is still held from the previous screen.
    while any(mouse.getPressed()):
        _redraw()

    # Accept only a fresh press (up -> down transition) that starts on a tile.
    was_down = False
    while True:
        _redraw()
        left_down = mouse.getPressed()[0]
        if left_down and not was_down:
            click_pos = mouse.getPos()
            hit = next((i for i, s in enumerate(stims) if s.contains(click_pos)), None)
            if hit is not None:
                # Wait for release so the click cannot leak into the next screen.
                while mouse.getPressed()[0]:
                    core.wait(0.01)
                core.wait(0.1)
                return hit, stims, header, display_items
        was_down = left_down

def parse_image_info(image_path):
    """
    Extract the category folder name and image number from an image path.

    Both '/' and '\\' separators are accepted, e.g.
    '.../category_03/img2_0.83.jpg' -> ('category_03', 2).

    Returns:
        tuple: (category, image_number)
            category: name of the file's parent folder, or None if the path
                has no parent component.
            image_number: the integer between the leading 'img' and the first
                '_' of the filename, or None if the filename does not follow
                that pattern (e.g. 'photo_0.5.jpg', 'imgA_0.5.jpg', 'img_0.5.jpg').
                Previously such names raised ValueError.
    """
    parts = image_path.replace("\\", "/").split("/")
    category = parts[-2] if len(parts) >= 2 else None
    filename = parts[-1]
    image_number = None
    if filename.startswith("img") and "_" in filename:
        try:
            image_number = int(filename[3:].split("_")[0])
        except ValueError:
            image_number = None
    return category, image_number

# -------------------- Feedback (practice only) --------------------
def show_feedback_highlight(win, stims, header, correct_idx, selected_idx, correct):
    """
    Show practice feedback on top of the choice grid for FEEDBACK_DURATION seconds.

    Correct: the header reads "Correct." and the tile at `correct_idx` gets a
    green frame. Incorrect: the header explains the error, the correct tile is
    framed green and the selected tile red. Afterwards the screen stays up
    until all mouse buttons are released, then pauses 0.1 s and clears the
    event buffer, so no click carries into the next screen.
    """
    if correct:
        header.text = "Correct."
    else:
        header.text = "Incorrect. The correct image is indicated below."

    def _frame(stim, color):
        width, height = stim.size
        return visual.Rect(win, width=width + 10, height=height + 10, pos=stim.pos,
                           lineColor=color, fillColor=None, lineWidth=6)

    boxes = [_frame(stims[correct_idx], "green")]
    if not correct:
        boxes.append(_frame(stims[selected_idx], "red"))

    def _draw_feedback():
        header.draw()
        for stim in stims:
            stim.draw()
        for box in boxes:
            box.draw()
        win.flip()

    timer = core.Clock()
    mouse = event.Mouse(visible=True, win=win)
    while timer.getTime() < FEEDBACK_DURATION:
        _draw_feedback()

    # Keep showing the feedback while any button is still held down.
    while any(mouse.getPressed()):
        _draw_feedback()
    core.wait(0.1)
    event.clearEvents()

# -------------------- Instructions Screen --------------------
def show_instructions(win):
    """
    Show the task instructions with a 'Continue to practice' button.

    The screen is redrawn every frame until the participant either presses
    Return/Space or left-clicks inside the button; both are followed by a
    0.15 s pause. Escape calls core.quit().
    """
    message = (
        "In the following experiment, you will see a sequence of four images. "
        "Remember those images carefully.\n\n"
        "After the sequence, you will be prompted with a number from 1 to 4 that indicates "
        "one of the images in the sequence. Call that image to mind and imagine it as vividly as possible.\n\n"
        "As soon as you can “see” the image with your inner eyes, press the space bar. "
        "Hold that mental image during the black screen that follows.\n\n"
        "Then, you will be shown a selection of eight images and asked to choose the one you just imagined. "
        "This experiment will be repeated several times.\n\n"
        "Click *Continue to practice* to start."
    )

    body = visual.TextStim(win, text=message, color=TEXT_COLOR, height=28,
                           wrapWidth=win.size[0] * 0.8, pos=(0, 120), alignText='center')
    button_pos = (0, -200)
    button_rect = visual.Rect(win, width=320, height=60, pos=button_pos,
                              lineColor='white', fillColor=None)
    button_label = visual.TextStim(win, text="Continue to practice", color=TEXT_COLOR,
                                   height=28, pos=button_pos)
    screen = (body, button_rect, button_label)

    mouse = event.Mouse(visible=True, win=win)

    while True:
        win.clearBuffer()
        for component in screen:
            component.draw()
        win.flip()

        pressed_keys = event.getKeys(keyList=['return', 'space', 'escape'])
        if 'escape' in pressed_keys:
            core.quit()
        keyboard_continue = 'return' in pressed_keys or 'space' in pressed_keys
        # The mouse is only queried when the keyboard did not already continue.
        if keyboard_continue or (mouse.getPressed()[0] and button_rect.contains(mouse.getPos())):
            core.wait(0.15)
            return

# -------------------- Trial Helper --------------------
def run_single_trial(
    win, clock, four_img_stims, chosen_paths, memorability_scores,
    order_first, num, text_stim, phase="main", give_feedback=False,
    trigger=None, experiment_clock=None
):
    """
    Run one full trial and return its results as a dict.

    Trial sequence:
        1) the images in `four_img_stims`, shown in `order_first`
        2) dynamic noise (NOISE1_DURATION)
        3) cue number `num`; wait for space (cue -> space RT)
        4) black screen (BLACK_DURATION). The trigger (if any) fires, and
           relative/absolute onset times are captured, on the flip that shows it.
        5) 8-tile choice grid (4 originals + 4 mirrors, randomized every call)
        6) optional feedback (practice)
        7) inter-trial dynamic noise (NOISE2_DURATION)

    Args:
        order_first: presentation order, indices 0..3 into `four_img_stims`
            and `chosen_paths`.
        num: cued serial position (1..4). The target is
            `chosen_paths[order_first[num - 1]]`.
        memorability_scores: {image_path: score}. "NA" is recorded if the
            target has no score.
        phase: label stored in the result ("practice" or "main").
        give_feedback: show correct/incorrect feedback after the choice.
        trigger: object with send_now(code) (e.g. TriggerSender), or None.
        experiment_clock: clock whose time is recorded at black-screen onset.

    Correctness is judged from the chosen tile's identity index, not from
    parsed filenames. This keeps it valid for filenames that do not follow the
    'img<N>_<score>' pattern, where parse_image_info yields no image number:
        correct_identity: the chosen tile shows the target (either orientation).
        correct_match:    correct_identity and the tile is not mirrored.
    """
    # 1) sequence
    show_images_sequence(win, four_img_stims, order_first, IMAGE_DURATION, clock)

    # 2) noise
    show_dynamic_noise(win, NOISE1_DURATION, clock, fps=FPS)

    # 3) cue -> space
    rt = show_number_and_wait_space(win, num, text_stim)

    # 4) black + EEG trigger on the *flip* that displays black
    onflip = None  # show_black skips the callback when None
    if trigger is not None:
        onflip = lambda: trigger.send_now(TRIGGER_CODE_BLACK_ONSET)
    onset_info = show_black(
        win, BLACK_DURATION,
        on_flip=onflip,
        clock=experiment_clock  # capture both rel and abs times at flip
    )
    t_rel_ms = None
    if onset_info["t_rel_s"] is not None:
        t_rel_ms = int(round(onset_info["t_rel_s"] * 1000.0))
    t_abs_str = None
    if onset_info["t_abs_dt"] is not None:
        # Format absolute time as HH:MM:SS.mmm
        t_abs_str = onset_info["t_abs_dt"].strftime("%H:%M:%S.%f")[:-3]

    # 5) 8-choice (randomized layout).
    # Zero the RT clock on the first flip that shows the grid, so the time
    # spent loading the 8 tile images is not counted as response time.
    rt_clock = core.Clock()
    win.callOnFlip(rt_clock.reset)
    selected_idx, stims, header, display_items = show_choice_8_images(win, chosen_paths)
    # Still includes the button-release wait and 0.1 s pause inside show_choice_8_images.
    choice_rt = rt_clock.getTime()

    # Target and selection
    target_img_idx = order_first[num - 1]          # index 0..3 into chosen_paths
    target_path = chosen_paths[target_img_idx]
    target_cat, target_img_num = parse_image_info(target_path)

    sel_info = display_items[selected_idx]
    selected_flip = sel_info['flip']               # True if mirror
    selected_path = sel_info['path']
    selected_cat, selected_img_num = parse_image_info(selected_path)
    target_memorability = memorability_scores.get(target_path, "NA")

    correct_identity = sel_info['id_idx'] == target_img_idx
    correct_match = correct_identity and not selected_flip  # must be original (not mirrored)

    if give_feedback:  # for practice trials only
        # The original (non-mirrored) target tile is always in the grid; 0 is a defensive fallback.
        correct_idx = next(
            (i for i, itm in enumerate(display_items)
             if itm['id_idx'] == target_img_idx and not itm['flip']),
            0,
        )
        show_feedback_highlight(win, stims, header, correct_idx, selected_idx, correct_match)

    # 6) inter-trial noise
    show_dynamic_noise(win, NOISE2_DURATION, clock, fps=FPS)

    # Note: trial_stopped_on will be set when appending this row in the main loop
    return {
        "phase": phase, # whether practice or main
        "order_first": order_first, # sequence of selected images
        "number_shown": num, # which image in the sequence was prompted
        "target_image_path": target_path, # prompted image = target
        "target_category": target_cat,
        "target_image_number": target_img_num,
        "target_memorability": target_memorability, # mem score of prompted image
        "reaction_time": rt, # time between prompted number and pressing space to trigger black screen signaling having identified and now imagining prompted image
        "selected_image_path": selected_path, # referring to selected image in the 8 image grid
        "selected_category": selected_cat,
        "selected_image_number": selected_img_num,
        "selected_flipped": selected_flip,
        "choice_rt": choice_rt, # time from grid onset to the accepted click (+ release wait)
        "correct_match": correct_match, # whether the selected image is exactly the image prompted (taking into account orientation)
        "correct_identity": correct_identity, # whether the selected image is the same as the image prompted NOT considering orientation (so can be mirror image)
        "black_onset_time": onset_info["t_rel_s"],  # seconds
        "black_onset_ms": t_rel_ms,                 # ms since experiment start
        "black_onset_abs_time": t_abs_str,          # HH:MM:SS.mmm
        "trial_stopped_on": None                    # placeholder; will be set by caller
    }

# -------------------- Main Experiment --------------------
def main():
    """
    Run one full session and save its data.

    Flow: participant ID (typed in the terminal) -> instructions -> practice
    trials (practice images, with feedback) -> main trials (main images, no
    feedback) -> save results under data/<participant_id>/ -> closing screen.

    The data of every completed trial are saved, with an "_INTERRUPTED" suffix
    on the file names, when the session ends early in any of these ways:
        - escape + 'y' at a pause
        - escape at the main-task start screen
        - a crash
        - a SystemExit/KeyboardInterrupt (e.g. core.quit() via escape at the cue)
    Existing files are never overwritten; a timestamp is appended instead.
    """
    import traceback  # local import: only needed to report crashes

    # ---------------- INITIALIZATION ---------------------
    participant_id = ""
    while not participant_id:  # an empty ID would save files straight into data/ as "_imagery_experiment.*"
        participant_id = input("Enter participant ID: ").strip()  # type participant ID into terminal

    # --- Set up directory structure (save under data/<participant_id>/) ---
    script_dir = os.path.dirname(os.path.abspath(__file__))
    participant_dir = os.path.join(script_dir, "data", participant_id)
    os.makedirs(participant_dir, exist_ok=True)

    base_filename = f"{participant_id}_imagery_experiment"
    # Appended to the output names only if files from an earlier session already exist.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # PsychoPy window and clocks
    win = visual.Window(WIN_SIZE, fullscr=FULLSCREEN, color=BG_COLOR, units='pix')
    clock = core.Clock()
    experiment_clock = core.Clock()  # zero point for relative black-screen onsets (black_onset_ms)

    # Build trigger sender once and share it
    trigger = TriggerSender(backend=TRIGGER_BACKEND)

    # Instructions
    show_instructions(win)

    # ---- Load PRACTICE images
    practice_lib, practice_mem_scores = build_image_library(os.path.join(script_dir, PRACTICE_IMAGE_ROOT))
    if len(practice_lib) < 2:
        raise RuntimeError("Not enough practice categories found in 'images_practice'. "
                           "Expected at least category_01 and category_02.")
    practice_chosen_paths = choose_four_images(practice_lib)
    practice_four_img_stims = build_image_stims(win, practice_chosen_paths)

    # ---- Load MAIN images
    image_lib, memorability_scores = build_image_library(os.path.join(script_dir, IMAGE_ROOT))
    if len(image_lib) < 2:
        raise RuntimeError("Not enough main-task categories found in 'images'.")
    chosen_paths = choose_four_images(image_lib)
    four_img_stims = build_image_stims(win, chosen_paths)

    # Prep main orders / numbers
    orders = counterbalance_positions(TRIALS, n_items=4)
    numbers = balanced_numbers(TRIALS, n_items=4)
    text_stim = visual.TextStim(win, text="", color=TEXT_COLOR, height=60)

    results = []

    # Practice: independent random orders and cue numbers.
    practice_orders = []
    for _ in range(PRACTICE_TRIALS):
        order = list(range(4))
        random.shuffle(order)
        practice_orders.append(order)
    practice_numbers = [random.randint(1, 4) for _ in range(PRACTICE_TRIALS)]

    interrupted = False
    try:
        # -------------------- PRACTICE (uses practice images) --------------------
        for p in range(PRACTICE_TRIALS):
            res = run_single_trial(
                win=win, clock=clock,
                four_img_stims=practice_four_img_stims,
                chosen_paths=practice_chosen_paths,
                memorability_scores=practice_mem_scores,
                order_first=practice_orders[p],
                num=practice_numbers[p],
                text_stim=text_stim,
                phase="practice",
                give_feedback=True,
                trigger=trigger,
                experiment_clock=experiment_clock
            )
            res["trial"] = p + 1
            res["trial_stopped_on"] = None  # not applicable for practice
            results.append(res)

            if p < PRACTICE_TRIALS - 1 and _pause_wants_quit(win):
                interrupted = True
                return  # the finally clause still saves and shows the closing screen

        # Start main message
        start_main = visual.TextStim(win, text="Practice complete. Press space to begin the main task.", height=40, color='white')
        start_main.draw()
        win.flip()
        if 'escape' in event.waitKeys(keyList=['space', 'escape']):
            interrupted = True
            return

        # -------------------- MAIN (uses main images) --------------------
        for t in range(TRIALS):
            res = run_single_trial(
                win=win, clock=clock,
                four_img_stims=four_img_stims,
                chosen_paths=chosen_paths,
                memorability_scores=memorability_scores,
                order_first=orders[t],
                num=numbers[t],
                text_stim=text_stim,
                phase="main",
                give_feedback=False,  # no feedback during main
                trigger=trigger,
                experiment_clock=experiment_clock
            )
            res["trial"] = t + 1
            res["trial_stopped_on"] = t + 1  # recorded per main trial
            results.append(res)

            # Pause between trials; skipped after the last one so a completed
            # session can never be saved as "_INTERRUPTED".
            if t < TRIALS - 1 and _pause_wants_quit(win):
                interrupted = True
                break

    except Exception as e:
        print("⚠️ Experiment crashed with error:", e)
        traceback.print_exc()
        interrupted = True
    except BaseException:
        # SystemExit from core.quit() (escape at the cue) or KeyboardInterrupt:
        # flag the data as interrupted, let the finally clause save it, then propagate.
        interrupted = True
        raise
    finally:
        try:
            # Save BEFORE touching the window, so a display error cannot lose the data.
            if results:
                base_out = base_filename + ("_INTERRUPTED" if interrupted else "")
                if any(os.path.exists(os.path.join(participant_dir, base_out + ext))
                       for ext in (".csv", ".xlsx")):
                    base_out += f"_{timestamp}"  # never overwrite an earlier session
                saved = _save_results(results, participant_dir, base_out)
                print("\n✅ Results saved to:")
                for path in saved:
                    print(f"  {path}")
        finally:
            _close_session(win, trigger, interrupted)


def _confirm_quit(win):
    """Ask the participant to confirm quitting; return True if they pressed 'y'."""
    confirm_text = visual.TextStim(win, text="Are you sure you want to quit? (y/n)", height=40, color='white')
    confirm_text.draw()
    win.flip()
    confirm_keys = event.waitKeys(keyList=['y', 'n'])
    return 'y' in confirm_keys


def _pause_wants_quit(win):
    """
    Show the between-trial 'Press space to continue.' screen.

    Returns True only if the participant pressed escape AND confirmed with 'y'.
    """
    pause_text = visual.TextStim(win, text="Press space to continue.", height=40, color='white')
    pause_text.draw()
    win.flip()
    keys = event.waitKeys(keyList=['space', 'escape'])
    return 'escape' in keys and _confirm_quit(win)


def _build_summary(df_main):
    """
    Return a one-row DataFrame summarising the main-phase trials in `df_main`.

    Means are computed over trials grouped by outcome. Entries that are not
    numeric (e.g. the "NA" memorability placeholder) are ignored instead of
    making the mean raise.
    """
    def _where(col, value):
        # Boolean row mask; all False if the column is absent.
        if col not in df_main:
            return pd.Series(False, index=df_main.index)
        return df_main[col] == value

    def _mean(col, mask=None):
        if col not in df_main:
            return float('nan')
        values = pd.to_numeric(df_main[col], errors="coerce")
        if mask is not None:
            values = values[mask]
        return float(values.mean())

    def _count(col):
        return int(df_main[col].sum()) if col in df_main else 0

    match = _where("correct_match", True)
    identity = _where("correct_identity", True)
    incorrect = _where("correct_identity", False)

    total_correct_match = _count("correct_match")
    total_correct_identity = _count("correct_identity")

    return pd.DataFrame({
        "trials_completed": [len(df_main)], # how many main trials were completed
        "total_correct_matches": [total_correct_match], # how many times exactly the prompted image was identified in the grid
        "total_semantic_match": [total_correct_identity - total_correct_match], # how many times the mirror image was selected in the grid

        "mean_reaction_time": [_mean("reaction_time")], # mean time between prompt and pressing space
        "mean_rt_correct_matches": [_mean("reaction_time", match)], # correct_match = True
        "mean_rt_correct_identity": [_mean("reaction_time", identity)], # correct_identity = True
        "mean_rt_incorrect_matches": [_mean("reaction_time", incorrect)], # correct_identity = False

        "mean_choice_rt": [_mean("choice_rt")], # mean of choice_rt over all main trials
        "mean_choice_rt_correct_matches": [_mean("choice_rt", match)],
        "mean_choice_rt_correct_identity": [_mean("choice_rt", identity)],
        "mean_choice_rt_incorrect_matches": [_mean("choice_rt", incorrect)],

        "mean_mem_correct_matches": [_mean("target_memorability", match)],
        "mean_mem_correct_identity": [_mean("target_memorability", identity)],
        "mean_mem_incorrect_matches": [_mean("target_memorability", incorrect)],
    })


def _save_results(results, participant_dir, base_out):
    """
    Write all trial rows to `participant_dir`; return the paths actually written.

    Files written:
        <base_out>.csv   every trial row (practice + main)
        <base_out>.xlsx  three sheets:
            'trials'       all rows
            'summary'      main trials only
            'black_onsets' trial, phase, black onset in ms and as HH:MM:SS.mmm
    If neither openpyxl nor xlsxwriter can write the workbook, the three sheets
    are saved instead as <base_out>_trials.csv, <base_out>_summary.csv and
    <base_out>_black_onsets.csv.
    """
    csv_path = os.path.join(participant_dir, base_out + ".csv")
    xlsx_path = os.path.join(participant_dir, base_out + ".xlsx")
    written = []

    # ------ STORE RESULTS (CSV) ---------------------------------
    # Union of keys in first-seen order, so a row carrying an extra key cannot abort the save.
    fieldnames = list(dict.fromkeys(key for row in results for key in row))
    # UTF-8 so non-ASCII characters in paths/IDs can never make the write fail.
    with open(csv_path, "w", newline='', encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)
    written.append(csv_path)

    # ------ STORE RESULTS (Excel with three sheets) -------------
    df = pd.DataFrame(results)
    # The summary only considers actual experimental (main) trials.
    summary_df = _build_summary(df[df["phase"] == "main"].copy())

    cols_needed = ["trial", "phase", "black_onset_ms", "black_onset_abs_time"]
    for c in cols_needed:
        if c not in df.columns:
            df[c] = None
    black_df = df.loc[:, cols_needed].copy().sort_values(["phase", "trial"])

    sheets = [("trials", df), ("summary", summary_df), ("black_onsets", black_df)]
    engine_errors = []
    for engine in ("openpyxl", "xlsxwriter"):
        try:
            with pd.ExcelWriter(xlsx_path, engine=engine) as writer:
                for sheet_name, frame in sheets:
                    frame.to_excel(writer, sheet_name=sheet_name, index=False)
        except Exception as e:  # engine missing or write failed: try the next option
            engine_errors.append(f"{engine}: {e}")
            continue
        written.append(xlsx_path)
        return written

    # CSV fallbacks carry the same base name (and suffix) as the workbook would.
    fallback_paths = []
    for sheet_name, frame in sheets:
        path = os.path.join(participant_dir, f"{base_out}_{sheet_name}.csv")
        frame.to_csv(path, index=False)
        fallback_paths.append(path)
    print("⚠️ Could not write Excel file (openpyxl/xlsxwriter missing). Saved CSVs instead:")
    for err in engine_errors:
        print("   reason:", err)
    for path in fallback_paths:
        print(" -", path)
    return written + fallback_paths


def _close_session(win, trigger, interrupted):
    """Show the closing message for 2 s, close the window, and close the serial trigger port if one is open."""
    try:
        closing_text = visual.TextStim(win, text=("Experiment ended." if interrupted else "Thank you!"),
                                       color=TEXT_COLOR, height=60)
        closing_text.draw()
        win.flip()
        core.wait(2.0)
        win.close()
    except Exception as e:
        print("⚠️ Could not show the closing screen:", e)
    finally:
        ser = getattr(trigger, "ser", None)  # only set for the serial backend
        if ser is not None:
            ser.close()

if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
