# Neural Correlates of Memory and Attention Study 2025
# inspired by: Shatek, S. M., Grootswagers, T., Robinson, A. K., & Carlson, T. A. (2019). Decoding Images in the Mind’s Eye: The Temporal Dynamics of Visual Imagery. Vision, 3(4), 53. https://doi.org/10.3390/vision3040053

## IMPORT MODULES
import os
import random
import csv
import datetime
from collections import Counter

import numpy as np
from psychopy import visual, core, event
import pandas as pd

# ----- Photodiode optical trigger -----
# A small white square flashed at black-screen onset so a photodiode taped to
# the monitor corner can time-lock EEG to the actual display refresh.
PHOTODIODE_FLASH_MS = 120       # flash duration; kept ≥100 ms (enforced again in show_black)
PHOTODIODE_SIZE_PX = 60         # side length of the white square, in pixels
PHOTODIODE_CORNER = 'top-right' # 'top-left' | 'top-right' | 'bottom-left' | 'bottom-right'

def _photodiode_square(win):
    """Build the white photodiode square positioned in the configured corner.

    Coordinates are in pixels, in PsychoPy's centered coordinate system,
    inset 5 px from the window edges. Unknown PHOTODIODE_CORNER values fall
    back to 'bottom-right', matching the original if/elif chain.
    """
    size = PHOTODIODE_SIZE_PX
    half_w, half_h = win.size[0] // 2, win.size[1] // 2
    inset = size / 2 + 5  # half the square plus a 5 px margin

    corner_positions = {
        'top-left':     (-half_w + inset,  half_h - inset),
        'top-right':    ( half_w - inset,  half_h - inset),
        'bottom-left':  (-half_w + inset, -half_h + inset),
        'bottom-right': ( half_w - inset, -half_h + inset),
    }
    pos = corner_positions.get(PHOTODIODE_CORNER, corner_positions['bottom-right'])

    return visual.Rect(win, width=size, height=size, pos=pos,
                       fillColor='white', lineColor='white')


# -------------------- Trigger I/O (only needed for purchased triggers) --------------------
TRIGGER_BACKEND = 'none'         # 'parallel' | 'serial' | 'none' (see TriggerSender)
TRIGGER_PULSE_MS = 10            # width of the trigger pulse before the line is reset to 0
TRIGGER_CODE_BLACK_ONSET = 11    # event code sent at black-screen (imagery) onset
PARALLEL_ADDRESS = 0x0378        # standard LPT1 base address
SERIAL_PORT = 'COM3'             # serial trigger device port name
SERIAL_BAUD = 115200

class TriggerSender:
    """Thin wrapper around an optional hardware trigger port.

    backend: 'parallel' (LPT via psychopy.hardware.parallel),
             'serial'   (raw bytes over a COM port), or
             'none'     (no-op; triggers are recorded optically instead).

    If hardware initialisation fails, the instance degrades to the 'none'
    backend instead of raising, so the experiment still runs without triggers.
    """

    def __init__(self, backend='none'):
        self.backend = backend
        self.pulse_s = TRIGGER_PULSE_MS / 1000.0  # pulse width in seconds
        self.ready = False
        if backend == 'parallel':
            try:
                from psychopy.hardware import parallel
                self.port = parallel.ParallelPort(address=PARALLEL_ADDRESS)
                self.port.setData(0)  # start with all lines low
                self.ready = True
            except Exception as e:
                print(f"⚠️ Parallel trigger init failed: {e}. Falling back to 'none'.")
                self.backend = 'none'
                # FIX: mark ready so the fallback matches the explicit 'none'
                # path (send_now is a no-op for 'none' either way, but the
                # object no longer looks uninitialised to callers).
                self.ready = True
        elif backend == 'serial':
            try:
                import serial
                self.ser = serial.Serial(SERIAL_PORT, SERIAL_BAUD, timeout=0)
                self.ready = True
            except Exception as e:
                print(f"⚠️ Serial trigger init failed: {e}. Falling back to 'none'.")
                self.backend = 'none'
                self.ready = True  # same consistency fix as the parallel branch
        else:
            self.ready = True

    def send_now(self, code: int):
        """Emit trigger `code` (masked to 8 bits) as a short pulse, then reset to 0.

        No-op when the backend is 'none' or the sender is not ready.
        Blocks for ~TRIGGER_PULSE_MS while the pulse is held high.
        """
        if not self.ready:
            return
        code = int(code) & 0xFF  # hardware trigger lines carry a single byte
        if self.backend == 'parallel':
            self.port.setData(code)
            core.wait(self.pulse_s)
            self.port.setData(0)
        elif self.backend == 'serial':
            self.ser.write(bytes([code]))
            core.wait(self.pulse_s)
            self.ser.write(b'\x00')
        else:
            pass  # 'none': black onset is still captured via the photodiode flash

# ----- Set working directory to script location -----
# FIX: os.path.dirname(__file__) is "" when the script is launched from its
# own directory (e.g. `python script.py`), and os.chdir("") raises
# FileNotFoundError. abspath() guarantees a non-empty directory and matches
# the abspath usage in main().
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# -------------------- Config --------------------
SEED = None  # set to an int for reproducible randomisation across both RNGs
if SEED is not None:
    random.seed(SEED)
    np.random.seed(SEED)

N_CATEGORIES = 19               # categories available on disk (informational)
IMAGES_PER_CATEGORY = 4         # at most this many images are loaded per category folder
PICKED_CATEGORIES = 2           # categories sampled per participant
IMAGES_PER_PICKED_CATEGORY = 2  # images sampled from each picked category (2x2 = 4 trial images)
TRIALS = 96 # in order to have all possibilities of image sequences and selected target
PRACTICE_TRIALS = 3

# durations (s)
IMAGE_DURATION = 1.000  # informed by: Potter, M. C. (2012). Recognition and Memory for Briefly Presented Scenes. Frontiers in Psychology, 3, 20118. https://doi.org/10.3389/fpsyg.2012.00032
NOISE1_DURATION = 0.500    # brief noise between sequence and cue
NOISE2_DURATION = 2.000    # inter-trial noise shown before next trial
BLACK_DURATION = 3.000     # imagery period; onset is trigger/photodiode time-locked
FEEDBACK_DURATION = 1.0 # practice feedback display time

# Vividness labels (1..5) shown under the rating buttons in show_vividness_rating
VIVIDNESS_LABELS = [
    "No image at all",
    "Dim and vague",
    "Moderately clear",
    "Clear and reasonably vivid",
    "Perfectly clear and as \n vivid as normal vision"
]

# paths
IMAGE_ROOT = "images" # main-task images root
PRACTICE_IMAGE_ROOT = "images_practice" # practice images root
# images from: A. Khosla, A. S. Raju, A. Torralba and A. Oliva, "Understanding and Predicting Image Memorability at a Large Scale," 2015 IEEE International Conference on Computer Vision (ICCV), Santiago, Chile, 2015, pp. 2390-2398, doi: 10.1109/ICCV.2015.275.

WIN_SIZE = [1280, 720]  # window size in pixels (superseded by FULLSCREEN when True)
FULLSCREEN = True
BG_COLOR = "black"
TEXT_COLOR = "white"
FPS = 60                # assumed display refresh rate, used for frame pacing

# -------------------- Image Setup --------------------
def build_image_library(root):
    """
    Scan `root` for "category_*" folders of score-tagged image files.

    Returns:
        dict: {category_name: [list_of_image_paths]}
        dict: {image_path: memorability_score}
    Assumes filenames like: img1_0.836912578.jpg or any pattern ending in _<float>.<ext>
    Categories with fewer than IMAGES_PER_PICKED_CATEGORY parseable images
    are skipped; a missing root yields two empty dicts.
    """
    lib = {}
    scores = {}
    if not os.path.isdir(root):
        return lib, scores
    for entry in sorted(os.listdir(root)):
        cat_dir = os.path.join(root, entry)
        # Only directories named category_* count as categories.
        if not os.path.isdir(cat_dir) or not entry.startswith("category_"):
            continue
        imgs = sorted([
            os.path.join(cat_dir, f) for f in os.listdir(cat_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"))
        ])
        valid = []
        for img_path in imgs[:IMAGES_PER_CATEGORY]:
            filename = os.path.basename(img_path)
            try:
                # The memorability score is the float between the last "_"
                # and the file extension.
                score_str = os.path.splitext(filename.rsplit("_", 1)[-1])[0]
                scores[img_path] = float(score_str)
                valid.append(img_path)
            except ValueError as e:
                # BUG FIX: the message previously printed the literal
                # "(unknown)" instead of the offending filename.
                print(f"⚠️ Failed to parse score from filename: {filename} — {e}")
        if len(valid) >= IMAGES_PER_PICKED_CATEGORY:
            lib[entry] = valid
    return lib, scores

def choose_four_images(image_lib):
    """Randomly pick PICKED_CATEGORIES categories from the library and
    IMAGES_PER_PICKED_CATEGORY images from each; return the combined paths.

    Raises ValueError (from random.sample) if the library has too few
    categories or a category has too few images.
    """
    cats = random.sample(list(image_lib.keys()), PICKED_CATEGORIES)
    chosen = []
    for cat in cats:
        chosen.extend(random.sample(image_lib[cat], IMAGES_PER_PICKED_CATEGORY))
    # FIX: derive the expected count from the config constants instead of the
    # hard-coded literal 4, so this check stays correct if the config changes.
    assert len(chosen) == PICKED_CATEGORIES * IMAGES_PER_PICKED_CATEGORY
    return chosen

def counterbalance_positions(n_trials, n_items=4):
    """Return n_trials presentation orders, counterbalanced in blocks.

    Each block of n_items consecutive trials contains every cyclic rotation
    of [0..n_items-1] exactly once, in random order, so every image appears
    in every serial position equally often. As before, n_trials is assumed
    to be a multiple of n_items; any remainder trials are dropped.

    FIX: the base orders were hard-coded for n_items == 4, silently ignoring
    the n_items parameter; they are now derived from n_items. Inner lists are
    also copied per block so callers can mutate orders without aliasing.
    """
    base_orders = [[(start + k) % n_items for k in range(n_items)]
                   for start in range(n_items)]
    orders = []
    for _ in range(n_trials // n_items):
        block = [list(order) for order in base_orders]  # fresh copies, no aliasing
        random.shuffle(block)
        orders.extend(block)
    return orders

def balanced_numbers(n_trials, n_items=4):
    """Return a shuffled list of cue numbers in which each of 1..n_items
    appears n_trials // n_items times (n_trials assumed divisible)."""
    repetitions = n_trials // n_items
    counts = [*range(1, n_items + 1)] * repetitions
    random.shuffle(counts)
    return counts

# -------------------- Visual helpers --------------------
def draw_progress_bar(win, fraction: float):
    """Draw a horizontal progress bar near the bottom of the window.

    `fraction` is clamped to [0, 1]; the filled portion grows left-to-right
    inside a fixed white outline spanning 80% of the window width.
    """
    fraction = min(1.0, max(0.0, float(fraction)))
    bar_width = win.size[0] * 0.8
    bar_height = 20
    y = -win.size[1] * 0.45

    # Outline covering the full bar extent.
    frame = visual.Rect(win, width=bar_width, height=bar_height, pos=(0, y),
                        lineColor='white', fillColor=None, lineWidth=2)
    # Filled part, left-aligned: its centre sits half of its own width to the
    # right of the bar's left edge.
    filled_width = bar_width * fraction
    filled = visual.Rect(win, width=filled_width, height=bar_height,
                         pos=(-bar_width / 2 + filled_width / 2, y),
                         lineColor=None, fillColor='white')
    frame.draw()
    filled.draw()

def overlay_with_progress(win, fraction):
    """Return a zero-argument callable that draws the progress bar at
    `fraction`; suitable for passing as an `overlay` hook to screen helpers."""
    return lambda: draw_progress_bar(win, fraction)

def build_image_stims(win, image_paths):
    """Create one square ImageStim per path, each sized to 60% of the
    window's shorter dimension."""
    side = min(win.size) * 0.6
    return [visual.ImageStim(win, image=path, size=(side, side))
            for path in image_paths]

# ====== Noise helpers ======
# use of dynamic noise inspired by: Wilson, H., Chen, X., Golbabaee, M., Proulx, M. J., & O’Neill, E. (2023). Feasibility of decoding visual information from EEG. Brain-Computer Interfaces, 11(1–2), 33–60. https://doi.org/10.1080/2326263X.2023.2287719
def _noise_frame(h=512, w=512, c=3):
    return (np.random.rand(h, w, c).astype(np.float32) * 2.0) - 1.0

def make_noise_stim(win):
    """Build a full-window ImageStim seeded with one random-noise frame.

    interpolate=False keeps the per-pixel noise crisp instead of smoothing it.
    """
    initial_frame = _noise_frame()
    return visual.ImageStim(win, image=initial_frame, size=win.size, interpolate=False)

def show_dynamic_noise(win, duration_s, clock, fps=60, overlay=None):
    """Flip a fresh random-noise frame every video frame for duration_s seconds.

    `overlay`, if given, is a zero-argument callable drawn on top of each
    frame (e.g. the progress bar).
    """
    noise = make_noise_stim(win)
    frame_time = 1.0 / fps
    start = clock.getTime()
    while clock.getTime() - start < duration_s:
        # Replace the texture each frame so the noise is dynamic, not static.
        noise.image = _noise_frame()
        noise.draw()
        if overlay:
            overlay()
        win.flip()
        core.wait(frame_time / 2.0)  # short sleep so the loop isn't busy-waiting
# ========================================================================

def show_images_sequence(win, img_stims, order, duration_s, clock, overlay=None):
    """Present the images indexed by `order`, each for duration_s seconds.

    Keyboard events are drained before and after the sequence so a space
    press during viewing cannot leak into the later imagery cue.
    """
    event.clearEvents(eventType='keyboard')

    for stim_index in order:
        img_stims[stim_index].draw()
        if overlay:
            overlay()
        win.flip()
        core.wait(duration_s)

    # Drain again in case keys were pressed while the images were up.
    event.clearEvents(eventType='keyboard')

def show_number_and_wait_space(win, number, text_stim, overlay=None):
    """Display the cue number until the participant presses space.

    Returns the reaction time in seconds from cue onset to the space press
    (i.e. time to call the cued image to mind). Escape quits the experiment.
    """
    text_stim.text = str(number)
    rt_clock = core.Clock()
    while True:
        text_stim.draw()
        if overlay:
            overlay()
        win.flip()
        pressed = event.getKeys(keyList=['space', 'escape'])
        if 'escape' in pressed:
            core.quit()
        if 'space' in pressed:
            return rt_clock.getTime()

def show_black(win, duration_s, on_flip=None, clock=None, flash_photodiode=True):
    """Show a black screen for duration_s. Optionally flash a small white square
    in the corner for >=100 ms at onset (for photodiode).

    Args:
        win: PsychoPy window (background already BG_COLOR).
        duration_s: how long to hold the black screen, in seconds.
        on_flip: optional callable fired at the first flip (e.g. hardware trigger).
        clock: optional Clock sampled at the first flip for a relative timestamp.
        flash_photodiode: whether to draw the corner square during the flash window.

    Returns:
        dict with:
            "t_rel_s":  clock.getTime() captured at the first flip (None if no clock)
            "t_abs_dt": datetime.datetime.now() captured at the first flip
    """
    onset_holder = {"t_rel_s": None, "t_abs_dt": None}
    first_flip_done = [False]  # one-element list so the nested callback can mutate it

    def _onflip_combo():
        # Time-lock the onset to the actual first flip
        # Capture absolute wall clock at the exact flip:
        onset_holder["t_abs_dt"] = datetime.datetime.now()
        # Capture relative experiment time at the exact flip:
        if clock is not None:
            onset_holder["t_rel_s"] = clock.getTime()
        # Then optionally fire the hardware trigger:
        if on_flip is not None:
            on_flip()
        first_flip_done[0] = True

    win.color = BG_COLOR

    # Render loop (so we can flash for an exact temporal window)
    frame_time = 1.0 / FPS
    flash_time_s = max(0.100, PHOTODIODE_FLASH_MS / 1000.0)  # enforce the 100 ms minimum
    t0 = core.Clock()
    flash_stim = _photodiode_square(win) if flash_photodiode else None

    while True:
        # Draw black background; (nothing to draw explicitly since Window has BG_COLOR)
        # Only draw the photodiode square for the first flash_time_s seconds
        t = t0.getTime()
        if flash_stim is not None and t < flash_time_s:
            flash_stim.draw()

        # Ensure the on-flip callback fires at the first flip only
        # (callOnFlip registrations are one-shot; re-register until it has run)
        if not first_flip_done[0]:
            win.callOnFlip(_onflip_combo)

        win.flip()

        if t >= duration_s:
            break

        # Avoid tight loop
        core.wait(frame_time / 2.0)

    return onset_holder

def show_choice_8_images(win, img_paths, overlay=None):
    """
    Build a randomized 8-tile grid (4 originals + 4 mirrors) and return:
      - selected index in DISPLAY order
      - stims in DISPLAY order (for feedback overlay)
      - header TextStim
      - display_items: list of dicts, each with:
          {'id_idx': 0..3, 'flip': bool, 'path': str}
    Blocks until the participant makes a fresh left-click inside a tile.
    """
    # Grid geometry: 4 columns, 2 rows, shifted slightly downward.
    spacing_x = win.size[0] * 0.2
    square_size = min(win.size) * 0.25
    x_positions = [-1.5 * spacing_x, -0.5 * spacing_x, 0.5 * spacing_x, 1.5 * spacing_x]
    # Row placement (top / bottom)
    row_shift = win.size[1] * 0.05
    base_row_y = win.size[1] * 0.25
    top_y = base_row_y - row_shift
    bottom_y = -base_row_y - row_shift

    header = visual.TextStim(
        win,
        text="Which image did you imagine?",
        color=TEXT_COLOR,
        height=36,
        pos=(0, win.size[1] * 0.40),
        wrapWidth=win.size[0] * 0.9,
        alignText='center'
    )

    #  Build the 8 items (4 identities x {orig, mirror}) and RANDOMIZE order
    items = []
    for i in range(4):
        items.append({'id_idx': i, 'flip': False, 'path': img_paths[i]})
        items.append({'id_idx': i, 'flip': True,  'path': img_paths[i]})
    random.shuffle(items)

    # Lay out as 4 columns x 2 rows in shuffled order
    # (items[0:4] fill the top row, items[4:8] the bottom row)
    stims = []
    display_items = []
    for col in range(4):
        itm = items[col]
        stim = visual.ImageStim(win, image=itm['path'], size=(square_size, square_size),
                                pos=(x_positions[col], top_y), flipHoriz=itm['flip'])
        stims.append(stim); display_items.append(itm)
    for col in range(4):
        itm = items[4 + col]
        stim = visual.ImageStim(win, image=itm['path'], size=(square_size, square_size),
                                pos=(x_positions[col], bottom_y), flipHoriz=itm['flip'])
        stims.append(stim); display_items.append(itm)

    # --- CLICK HANDLING: require a fresh click that starts inside a tile ---
    mouse = event.Mouse(visible=True, win=win)
    # 1) Ensure no buttons are currently down (swallow carry-over clicks)
    while any(mouse.getPressed()):
        header.draw()
        for stim in stims: stim.draw()
        if overlay: overlay()
        win.flip()
    # 2) Now wait for a new click (transition)
    pressed_prev = (0,0,0)
    while True:
        header.draw()
        for stim in stims: stim.draw()
        if overlay: overlay()
        win.flip()

        pressed = mouse.getPressed()
        # Detect new left-click (down transition)
        if pressed[0] and not pressed_prev[0]:
            pos = mouse.getPos()
            # Only accept if the click STARTED inside a tile
            for i, stim in enumerate(stims):
                if stim.contains(pos):
                    # Wait briefly for button release to avoid double-triggers
                    while mouse.getPressed()[0]:
                        core.wait(0.01)
                    core.wait(0.1)
                    return i, stims, header, display_items
        pressed_prev = pressed

def parse_image_info(image_path):
    """
    Extract category name and image number from a file path like:
    '.../category_03/img2_0.83.jpg'

    Returns:
        (category, image_number) where image_number is an int, or None when
        the filename does not follow the img<N>_<score> pattern.
    """
    # Normalise Windows separators so the same parsing works on both OSes.
    parts = image_path.replace("\\", "/").split("/")
    category = parts[-2]
    filename = parts[-1]
    image_number = None
    if filename.startswith("img") and "_" in filename:
        try:
            image_number = int(filename[3:].split("_")[0])
        except ValueError:
            # FIX: e.g. "imgfoo_0.5.jpg" previously crashed on int();
            # treat an unparseable stem as "no number" instead.
            image_number = None
    return category, image_number

# -------------------- Feedback (practice only) --------------------
def show_feedback_highlight(win, stims, header, correct_idx, selected_idx, correct, overlay=None):
    """
    - If correct: header "Correct.", selected image framed GREEN.
    - If incorrect: header "Incorrect. The correct image is indicated below.",
      correct image framed GREEN and selected image framed RED.
    Swallows mouse presses so no clicks carry into the next screen.

    Args:
        stims: the 8 tile ImageStims in display order (from show_choice_8_images).
        header: the reusable header TextStim (its text is overwritten here).
        correct_idx / selected_idx: display-order indices into `stims`.
        correct: whether the participant's choice was an exact (flip-aware) match.
    """
    header.text = "Correct." if correct else "Incorrect. The correct image is indicated below."
    def outline_for(stim, color):
        # Rect slightly larger than the tile so the frame sits around it.
        w, h = stim.size
        return visual.Rect(win, width=w + 10, height=h + 10, pos=stim.pos,
                           lineColor=color, fillColor=None, lineWidth=6)
    green_box = outline_for(stims[correct_idx], "green")
    red_box = outline_for(stims[selected_idx], "red") if not correct else None

    # Show timed feedback
    t0 = core.Clock()
    mouse = event.Mouse(visible=True, win=win)
    while t0.getTime() < FEEDBACK_DURATION:
        header.draw()
        for s in stims: s.draw()
        green_box.draw()
        if red_box: red_box.draw()
        if overlay: overlay()
        win.flip()

    # --- Swallow any clicks before leaving feedback ---
    # Keep drawing the same frame while a button is held so the screen
    # doesn't flicker during the wait.
    while any(mouse.getPressed()):
        header.draw()
        for s in stims: s.draw()
        green_box.draw()
        if red_box: red_box.draw()
        if overlay: overlay()
        win.flip()
    core.wait(0.1)
    event.clearEvents()

# -------------------- Vividness Rating --------------------
def show_vividness_rating(win, labels):
    """Collect a 1-5 vividness rating via mouse click or number-key press.

    Args:
        labels: five descriptive strings displayed under buttons 1..5
                (see VIVIDNESS_LABELS).

    Returns:
        (rating, rt): the chosen integer 1..5 and the response time in
        seconds from screen onset. Escape quits the experiment.
    """
    instruction = visual.TextStim(
        win,
        text="Rate the quality of your visual imagery:",
        color=TEXT_COLOR,
        height=32,
        pos=(0, win.size[1]*0.25),
        wrapWidth=win.size[0]*0.9,
        alignText='center'
    )

    # Five evenly spaced boxes centred horizontally.
    box_w = win.size[0]*0.12
    box_h = 70
    spacing = win.size[0]*0.03
    total_w = 5*box_w + 4*spacing
    start_x = -total_w/2 + box_w/2

    boxes = []
    num_stims = []
    label_stims = []
    for i in range(5):
        x = start_x + i*(box_w + spacing)
        rect = visual.Rect(win, width=box_w, height=box_h, pos=(x, 0),
                           lineColor='white', fillColor=None, lineWidth=2)
        boxes.append(rect)
        num = visual.TextStim(win, text=str(i+1), color=TEXT_COLOR,
                              height=28, pos=(x, 0), alignText='center')
        num_stims.append(num)
        lab = visual.TextStim(win, text=labels[i], color=TEXT_COLOR,
                              height=16, pos=(x, -95), wrapWidth=box_w*1.3, alignText='center')
        label_stims.append(lab)

    prompt = visual.TextStim(win, text="Click a button (or press 1–5).", color=TEXT_COLOR, height=22, pos=(0, -160))
    mouse = event.Mouse(visible=True, win=win)
    rt_clock = core.Clock()

    # Swallow any carry-over mouse press from the previous screen first.
    while any(mouse.getPressed()):
        instruction.draw()
        for b in boxes: b.draw()
        for n in num_stims: n.draw()
        for l in label_stims: l.draw()
        prompt.draw()
        win.flip()

    # Main response loop: keyboard (1-5) or a fresh left-click inside a box.
    while True:
        instruction.draw()
        for b in boxes: b.draw()
        for n in num_stims: n.draw()
        for l in label_stims: l.draw()
        prompt.draw()
        win.flip()

        keys = event.getKeys(keyList=['1','2','3','4','5','escape'])
        if 'escape' in keys:
            core.quit()
        for k in keys:
            if k in ['1','2','3','4','5']:
                return int(k), rt_clock.getTime()

        if mouse.getPressed()[0]:
            pos = mouse.getPos()
            for i, rect in enumerate(boxes):
                if rect.contains(pos):
                    # Wait for release so the click can't leak onward.
                    while mouse.getPressed()[0]:
                        core.wait(0.01)
                    core.wait(0.05)
                    return i+1, rt_clock.getTime()

# -------------------- Simple button screens --------------------
def show_button_screen(win, message, button_text="Continue", draw_progress=None):
    """Centered message with a clickable button; also Enter/Space.

    Blocks until the participant clicks the button or presses Enter/Space;
    Escape quits the experiment. `draw_progress`, if given, is drawn each frame.
    """
    text = visual.TextStim(win, text=message, color=TEXT_COLOR, height=28,
                           wrapWidth=win.size[0]*0.85, pos=(0, 100), alignText='center')
    btn_rect = visual.Rect(win, width=320, height=60, pos=(0, -120),
                           lineColor='white', fillColor=None)
    btn_label = visual.TextStim(win, text=button_text, color=TEXT_COLOR, height=28, pos=(0, -120))
    mouse = event.Mouse(visible=True, win=win)

    while True:
        # Redraw the whole screen every frame.
        text.draw()
        btn_rect.draw()
        btn_label.draw()
        if draw_progress:
            draw_progress()
        win.flip()

        # Keyboard path: Enter/Space continue, Escape aborts.
        keys = event.getKeys(keyList=['return', 'space', 'escape'])
        if 'escape' in keys:
            core.quit()
        if 'return' in keys or 'space' in keys:
            core.wait(0.15)
            return

        # Mouse path: left-click inside the button, then wait for release
        # so the click cannot carry over into the next screen.
        if mouse.getPressed()[0] and btn_rect.contains(mouse.getPos()):
            while mouse.getPressed()[0]:
                core.wait(0.01)
            core.wait(0.1)
            return

# -------------------- Instructions Screen --------------------
def show_instructions(win):
    """Show the task instructions; continue on button click, Enter, or Space.

    Escape quits the experiment.
    """
    instructions = (
        "In the following experiment, you will see a sequence of four images. "
        "Remember those images carefully.\n\n"
        "After the sequence, you will be prompted with a number from 1 to 4 that indicates "
        "one of the images in the sequence. Call that image to mind and imagine it as vividly as possible.\n\n"
        "As soon as you can “see” the image with your inner eyes, press the space bar. "
        "Hold that mental image during the black screen that follows.\n\n"
        "Then, you will be shown a selection of eight images and asked to choose the one you just imagined. "
        "This experiment will be repeated several times.\n\n"
        "Click *Continue to practice* to start."
    )
    text = visual.TextStim(win, text=instructions, color=TEXT_COLOR, height=28,
                           wrapWidth=win.size[0] * 0.8, pos=(0, 120), alignText='center')
    btn_rect = visual.Rect(win, width=320, height=60, pos=(0, -200), lineColor='white', fillColor=None)
    btn_label = visual.TextStim(win, text="Continue to practice", color=TEXT_COLOR, height=28, pos=(0, -200))
    mouse = event.Mouse(visible=True, win=win)
    while True:
        win.clearBuffer()
        text.draw()
        btn_rect.draw(); btn_label.draw()
        win.flip()
        keys = event.getKeys(keyList=['return', 'space', 'escape'])
        if 'escape' in keys:
            core.quit()
        if 'return' in keys or 'space' in keys:
            core.wait(0.15)
            return
        if mouse.getPressed()[0]:
            if btn_rect.contains(mouse.getPos()):
                # FIX (consistency with show_button_screen / show_choice_8_images):
                # wait for the button release before returning so the click
                # cannot carry over into the first practice-trial screen.
                while mouse.getPressed()[0]:
                    core.wait(0.01)
                core.wait(0.15)
                return

# -------------------- Trial Helper --------------------
def run_single_trial(
    win, clock, four_img_stims, chosen_paths, memorability_scores,
    order_first, num, text_stim, phase="main", give_feedback=False,
    trigger=None, experiment_clock=None, progress_fraction=None
):
    """Run one trial: image sequence -> noise -> cue -> imagery (black) -> 8-choice.

    Args:
        win: PsychoPy window.
        clock: clock used for stimulus-duration timing.
        four_img_stims: the four ImageStims, index-aligned with chosen_paths.
        chosen_paths: the four image file paths for this phase.
        memorability_scores: {path: score} mapping from build_image_library.
        order_first: permutation of 0..3 giving this trial's presentation order.
        num: cue number 1..4, indexing (1-based) into order_first.
        text_stim: reusable TextStim used to display the cue number.
        phase: "practice" or "main"; logged verbatim.
        give_feedback: show correct/incorrect highlight (practice only).
        trigger: optional TriggerSender; fires TRIGGER_CODE_BLACK_ONSET at black onset.
        experiment_clock: clock sampled at black onset for relative timestamps.
        progress_fraction: if not None, a progress-bar overlay is drawn at this fraction.

    Returns:
        dict of per-trial log fields (see inline comments on each key);
        'trial_stopped_on', 'vividness_rating' and 'vividness_rt' are
        placeholders the caller fills in afterwards.
    """
    overlay = None
    if progress_fraction is not None:
        overlay = overlay_with_progress(win, progress_fraction)

    # 1) sequence
    show_images_sequence(win, four_img_stims, order_first, IMAGE_DURATION, clock, overlay=overlay)

    # 2) brief noise before cue
    show_dynamic_noise(win, NOISE1_DURATION, clock, fps=FPS, overlay=overlay)
    # ensure only presses during the cue are seen
    event.clearEvents(eventType='keyboard')

    # 3) cue -> space
    rt = show_number_and_wait_space(win, num, text_stim, overlay=overlay)

    # 4) black (trigger on flip)
    if trigger is None:
        def _noop(): pass
        onflip = _noop
    else:
        onflip = lambda: trigger.send_now(TRIGGER_CODE_BLACK_ONSET)

    onset_info = show_black(win, BLACK_DURATION, on_flip=onflip, clock=experiment_clock) # capture both rel and abs times at flip
    # Derive the requested timing fields
    t_rel_ms = int(round(onset_info["t_rel_s"] * 1000.0)) if onset_info["t_rel_s"] is not None else None
    t_abs_str = onset_info["t_abs_dt"].strftime("%H:%M:%S.%f")[:-3] if onset_info["t_abs_dt"] is not None else None # Format absolute time as HH:MM:SS.mmm

    # 5) 8-choice
    rt_clock = core.Clock()
    selected_idx, stims, header, display_items = show_choice_8_images(win, chosen_paths, overlay=overlay)
    choice_rt = rt_clock.getTime() # <-- selection RT captured

    # correctness (use stable indices instead of filename parsing)
    target_img_idx = order_first[num - 1]        # 0..3: which of the 4 images was cued
    sel_info = display_items[selected_idx]       # metadata for the clicked tile

    selected_id_idx = sel_info['id_idx']         # 0..3: which base image this tile represents
    selected_flip   = sel_info['flip']
    selected_path   = sel_info['path']

    # logging for later
    target_path = chosen_paths[target_img_idx]
    target_cat, target_img_num = parse_image_info(target_path)
    selected_cat, selected_img_num = parse_image_info(selected_path)

    # Correctness definitions (index-based, flip-aware)
    correct_match    = (selected_id_idx == target_img_idx and selected_flip is False)
    correct_identity = (selected_id_idx == target_img_idx)

    if give_feedback: # for practice trials only
        # Find which tile is the correct (original) target in the CURRENT randomized display
        correct_idx = None
        for i, itm in enumerate(display_items):
            if itm['id_idx'] == target_img_idx and itm['flip'] == False:
                correct_idx = i; break
        if correct_idx is None: correct_idx = 0
        show_feedback_highlight(win, stims, header, correct_idx, selected_idx, correct_match, overlay=overlay)

    return {
        "phase": phase, # whether practice or main
        "order_first": order_first, # sequence of selected images
        "number_shown": num, # which image in the sequence was prompted
        "target_image_path": target_path, # prompted image = target
        "target_category": target_cat,
        "target_image_number": target_img_num,
        "target_memorability": memorability_scores.get(target_path, "NA"), # mem score of prompted image
        "reaction_time": rt, # time between prompted number and pressing space to trigger black screen signaling having identified and now imagining prompted image
        "selected_image_path": selected_path, # referring to selected image in the 8 image grid
        "selected_category": selected_cat,
        "selected_image_number": selected_img_num,
        "selected_flipped": selected_flip,
        "target_id_idx": int(target_img_idx),
        "selected_id_idx": int(selected_id_idx),
        "choice_rt": choice_rt, # time to select one of the 8 images
        "correct_match": bool(correct_match), # whether the selected image is exactly the image prompted (taking into account orientation)
        "correct_identity": bool(correct_identity), # whether the selected image is the same as the image prompted NOT considering orientation (so can be mirror image)
        "black_onset_time": onset_info["t_rel_s"],  # seconds
        "black_onset_ms": t_rel_ms,                 # ms since experiment start
        "black_onset_abs_time": t_abs_str,          # HH:MM:SS.mmm
        "trial_stopped_on": None,                    # placeholder; will be set by caller
        "vividness_rating": None,
        "vividness_rt": None
    }


# -------------------- Main Experiment --------------------
def main() -> None:
    """Run the full imagery experiment for one participant.

    Flow:
      1. Prompt for a participant ID on the terminal and create data/<id>/.
      2. Open the PsychoPy window, show instructions.
      3. Practice phase (separate image set, feedback + vividness rating each trial).
      4. Main phase (no feedback; vividness rating every 4th trial; escape key
         offers a confirmed early quit).
      5. If not interrupted: second phase — the same 4 main images shown in one
         fixed random sequence, repeated 5 times, logged to a separate XLSX.
      6. In a ``finally`` block (so it runs even after a crash or early quit):
         close the window and write all collected trial rows to CSV plus a
         three-sheet Excel workbook (trials / summary / black_onsets), with
         CSV fallbacks if no Excel engine is available.

    Side effects: creates files under data/<participant_id>/, opens and closes
    a fullscreen-capable window, reads keyboard input.
    """
    # ---------------- INITIALIZATION ---------------------
    participant_id = input("Enter participant ID: ").strip() # type participant ID into terminal 

    # --- Set up directory structure (save under data/<participant_id>/) ---
    script_dir = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(script_dir, "data")
    participant_dir = os.path.join(data_dir, participant_id)
    os.makedirs(participant_dir, exist_ok=True)

    # --- Build output filenames (consistent base name) ---
    base_filename = f"{participant_id}_imagery_experiment"
    csv_filename = os.path.join(participant_dir, base_filename + ".csv")
    xlsx_filename = os.path.join(participant_dir, base_filename + ".xlsx")
    second_phase_csv = os.path.join(participant_dir, base_filename + "_sequence_phase.csv") # second-phase file path

    # NOTE(review): 'timestamp' is assigned here but never used below —
    # per-row timestamps are regenerated inline in the second phase.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # PsychoPy window and clocks.
    # 'clock' is passed into per-trial routines; 'experiment_clock' measures
    # time since experiment start (used for black-screen onset timing).
    win = visual.Window(WIN_SIZE, fullscr=FULLSCREEN, color=BG_COLOR, units='pix')
    clock = core.Clock()
    experiment_clock = core.Clock()

    # Build trigger sender once and share it
    trigger = TriggerSender(backend=TRIGGER_BACKEND)

    # Instructions
    show_instructions(win)

    # ---- Load PRACTICE images
    practice_lib, practice_mem_scores = build_image_library(os.path.join(script_dir, PRACTICE_IMAGE_ROOT))
    if len(practice_lib) < 2:
        raise RuntimeError("Not enough practice categories found in 'images_practice'.")
    practice_chosen_paths = choose_four_images(practice_lib)
    practice_four_img_stims = build_image_stims(win, practice_chosen_paths)

    # ---- Load MAIN images
    image_lib, memorability_scores = build_image_library(os.path.join(script_dir, IMAGE_ROOT))
    if len(image_lib) < 2:
        raise RuntimeError("Not enough main-task categories found in 'images'.")
    chosen_paths = choose_four_images(image_lib)  # <- these 4 images are reused in second phase
    four_img_stims = build_image_stims(win, chosen_paths)

    # Prep main orders / numbers (counterbalanced across TRIALS main trials)
    orders = counterbalance_positions(TRIALS, n_items=4)
    numbers = balanced_numbers(TRIALS, n_items=4)
    text_stim = visual.TextStim(win, text="", color=TEXT_COLOR, height=60)

    # One dict per completed trial (practice AND main); written out in 'finally'.
    results = []

    # -------------------- PRACTICE (uses practice images) --------------------
    # Practice orders/numbers are fully random (not counterbalanced like main).
    practice_orders = []
    for _ in range(PRACTICE_TRIALS):
        order = list(range(4))
        random.shuffle(order)
        practice_orders.append(order)
    practice_numbers = [random.randint(1, 4) for _ in range(PRACTICE_TRIALS)]

    for p in range(PRACTICE_TRIALS):
        res = run_single_trial(
            win=win, clock=clock,
            four_img_stims=practice_four_img_stims,
            chosen_paths=practice_chosen_paths,
            memorability_scores=practice_mem_scores,
            order_first=practice_orders[p],
            num=practice_numbers[p],
            text_stim=text_stim,
            phase="practice",
            give_feedback=True,
            trigger=trigger,
            experiment_clock=experiment_clock,
            progress_fraction=None
        )
        res["trial"] = p + 1
        # Vividness rating after EVERY practice trial (main phase: every 4th only).
        rating, rrt = show_vividness_rating(win, VIVIDNESS_LABELS)
        res["vividness_rating"] = int(rating)
        res["vividness_rt"] = float(rrt)
        res["trial_stopped_on"] = None  # not applicable for practice
        results.append(res)

        if p < PRACTICE_TRIALS - 1:
            pause_text = visual.TextStim(win, text="Press space to continue.", height=40, color='white')
            pause_text.draw(); win.flip()
            keys = event.waitKeys(keyList=['space', 'escape'])
            if 'escape' in keys:
                # Confirmed quit during practice: close the window and return
                # (collected practice rows are NOT saved on this path, since the
                # saving code lives in the main phase's try/finally below).
                confirm_text = visual.TextStim(win, text="Are you sure you want to quit? (y/n)", height=40, color='white')
                confirm_text.draw(); win.flip()
                confirm_keys = event.waitKeys(keyList=['y', 'n'])
                if 'y' in confirm_keys:
                    text_stim.text = "Experiment ended."
                    text_stim.draw(); win.flip(); core.wait(2.0); win.close()
                    return
            show_dynamic_noise(win, NOISE2_DURATION, clock, fps=FPS, overlay=None)

    # Start main message
    start_main = visual.TextStim(win, text="Practice complete. Press space to begin the main task.", height=40, color='white')
    start_main.draw(); win.flip()
    _ = event.waitKeys(keyList=['space', 'escape'])
    if 'escape' in _:
        # Quit at the practice→main boundary (same unsaved-exit path as above).
        text_stim.text = "Experiment ended."
        text_stim.draw(); win.flip(); core.wait(2.0); win.close()
        return

    # -------------------- MAIN (uses main images) --------------------
    interrupted = False
    trial_stopped_on = 0  # last main trial that was completed (1-based)

    try:
        for t in range(TRIALS):
            order_first = orders[t]
            num = numbers[t]
            progress_fraction = float(t) / float(TRIALS)

            res = run_single_trial(
                win=win, clock=clock,
                four_img_stims=four_img_stims,
                chosen_paths=chosen_paths,
                memorability_scores=memorability_scores,
                order_first=order_first,
                num=num,
                text_stim=text_stim,
                phase="main",
                give_feedback=False, # no feedback during main
                trigger=trigger,
                experiment_clock=experiment_clock,
                progress_fraction=progress_fraction
            )
            res["trial"] = t + 1
            trial_stopped_on = t + 1
            res["trial_stopped_on"] = trial_stopped_on # <-- recorded per main trial

            # Vividness rating only every 4th main trial; other trials keep
            # the None placeholders set by run_single_trial.
            if ((t + 1) % 4) == 0:
                rating, rrt = show_vividness_rating(win, VIVIDNESS_LABELS)
                res["vividness_rating"] = int(rating)
                res["vividness_rt"] = float(rrt)

            results.append(res)

            # Pause then inter-trial noise
            pause_fraction = float(t + 1) / float(TRIALS)
            pause_text = visual.TextStim(win, text="Press space to continue.", height=40, color='white')
            draw_progress_bar(win, pause_fraction); pause_text.draw(); win.flip()
            keys = event.waitKeys(keyList=['space', 'escape'])
            if 'escape' in keys:
                confirm_text = visual.TextStim(win, text="Are you sure you want to quit? (y/n)", height=40, color='white')
                draw_progress_bar(win, pause_fraction); confirm_text.draw(); win.flip()
                confirm_keys = event.waitKeys(keyList=['y', 'n'])
                if 'y' in confirm_keys:
                    # Unlike the practice quit, this path reaches 'finally',
                    # so data collected so far IS saved.
                    interrupted = True
                    break

            if t < TRIALS - 1:
                show_dynamic_noise(win, NOISE2_DURATION, clock, fps=FPS,
                                   overlay=overlay_with_progress(win, pause_fraction))

        # Only proceed to second phase if main was NOT interrupted
        if not interrupted:
            # End-of-main screen with button
            show_button_screen(
                win,
                message="Thank you for participating.\n\nPlease press Continue to move on to the next experiment.",
                button_text="Continue"
            )

            # Second-phase instructions screen
            second_instructions = (
                "In the following, you will see a sequence of 4 images.\n"
                "You will be shown the sequence 5 times.\n\n"
                "Please remember both the images and their sequence as you will be asked to recall them later."
            )
            show_button_screen(win, message=second_instructions, button_text="Begin")

            # === Second phase — SAME four images, ONE random sequence, repeat 5x -----
            REPS = 5
            DELAY_BETWEEN_REPS_S = 5.0

            # Build a single random order once
            fixed_order = list(range(4))
            random.shuffle(fixed_order)

            # Store both the images used and the single sequence in an Excel file
            second_phase_xlsx = os.path.join(participant_dir, base_filename + "_second_phase_instructions.xlsx")

            # Record rows per repetition (for a "sequence" sheet)
            second_phase_rows = []
            # Also record an "images" sheet with positions->paths mapping for clarity
            images_rows = [{
                "participant_id": participant_id,
                "timestamp": datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
                "pos1_image_path": chosen_paths[fixed_order[0]],
                "pos2_image_path": chosen_paths[fixed_order[1]],
                "pos3_image_path": chosen_paths[fixed_order[2]],
                "pos4_image_path": chosen_paths[fixed_order[3]],
                "order_indices": ",".join(str(i) for i in fixed_order)  # indices into chosen_paths
            }]

            # Reuse stims from main
            for r in range(1, REPS + 1):
                # show the 4 images in the fixed order
                show_images_sequence(win, four_img_stims, fixed_order, IMAGE_DURATION, clock, overlay=None)

                # record this repetition
                second_phase_rows.append({
                    "participant_id": participant_id,
                    "timestamp": datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
                    "repetition": r,
                    "order_indices": ",".join(str(i) for i in fixed_order),
                    "img_pos1_path": chosen_paths[fixed_order[0]],
                    "img_pos2_path": chosen_paths[fixed_order[1]],
                    "img_pos3_path": chosen_paths[fixed_order[2]],
                    "img_pos4_path": chosen_paths[fixed_order[3]],
                })

                if r < REPS:
                    # show dynamic noise for the entire wait duration
                    show_dynamic_noise(win, DELAY_BETWEEN_REPS_S, clock, fps=FPS, overlay=None)

            # save the second-phase XLSX (separate from main)
            # Engine fallback chain: openpyxl -> xlsxwriter -> plain CSV.
            # NOTE(review): if the DataFrame construction itself raised, the
            # except branches below would hit a NameError on df_seq/df_imgs —
            # in practice only the ExcelWriter calls are expected to fail here.
            try:
                df_seq = pd.DataFrame(second_phase_rows)
                df_imgs = pd.DataFrame(images_rows)
                with pd.ExcelWriter(second_phase_xlsx, engine="openpyxl") as writer:
                    df_imgs.to_excel(writer, sheet_name="images_used_and_order", index=False)
                    df_seq.to_excel(writer, sheet_name="sequence_repetitions", index=False)
                print(f"✅ Second-phase sequence saved to:\n  {second_phase_xlsx}")
            except Exception:
                try:
                    with pd.ExcelWriter(second_phase_xlsx, engine="xlsxwriter") as writer:
                        df_imgs.to_excel(writer, sheet_name="images_used_and_order", index=False)
                        df_seq.to_excel(writer, sheet_name="sequence_repetitions", index=False)
                    print(f"✅ Second-phase sequence saved to:\n  {second_phase_xlsx}")
                except Exception as e:
                    # Fallback to CSVs if Excel write fails
                    # NOTE(review): this rebinding shadows the second_phase_csv
                    # path built near the top of main() (which is otherwise unused).
                    second_phase_csv = os.path.splitext(second_phase_xlsx)[0] + ".csv"
                    df_seq.to_csv(second_phase_csv, index=False)
                    print(f"⚠️ Excel write failed ({e}). Wrote CSV instead:\n  {second_phase_csv}")

            # final thank-you screen for the whole session
            show_button_screen(win, message="Thank you!", button_text="Finish")

    except Exception as e:
        # Any crash inside the main/second phase still flows into 'finally',
        # so whatever was collected in 'results' gets saved.
        print("⚠️ Experiment crashed with error:", e)
        interrupted = True

    finally:
        # Closing screen (keep simple here; earlier flow already showed thanks)
        end_msg = "Experiment ended." if interrupted else "Thank you!"
        text_stim2 = visual.TextStim(win, text=end_msg, color=TEXT_COLOR, height=60)
        text_stim2.draw(); win.flip(); core.wait(2.0)
        win.close()

        if len(results) > 0:
            # ------ STORE RESULTS (CSV) ---------------------------------
            # Field names come from the first row; all rows are assumed to
            # share the same keys (run_single_trial returns a fixed dict).
            with open(csv_filename, "w", newline='') as f:
                writer = csv.DictWriter(f, fieldnames=results[0].keys())
                writer.writeheader()
                writer.writerows(results)

            # ------ STORE RESULTS (Excel with three sheets) -------------
            # Build DataFrames for Excel
            df = pd.DataFrame(results)
            df_main = df[df["phase"] == "main"].copy() # should ONLY consider actual experimental (main) trials

            # Counts
            total_trials_main = len(df_main)
            total_correct_match = int(df_main["correct_match"].sum()) if "correct_match" in df_main else 0
            total_correct_identity = int(df_main["correct_identity"].sum()) if "correct_identity" in df_main else 0
            # identity-correct but not exact-match = mirror-image ("semantic") selections
            total_semantic_match = total_correct_identity - total_correct_match

            # Mean cue→space RTs (reaction_time)
            mean_rt_all = float(df_main["reaction_time"].mean()) if "reaction_time" in df_main else float('nan')
            mean_rt_correct_match = float(df_main.loc[df_main["correct_match"] == True, "reaction_time"].mean()) if "reaction_time" in df_main else float('nan')
            mean_rt_correct_identity = float(df_main.loc[df_main["correct_identity"] == True, "reaction_time"].mean()) if "reaction_time" in df_main else float('nan')
            mean_rt_incorrect = float(df_main.loc[df_main["correct_identity"] == False, "reaction_time"].mean()) if "reaction_time" in df_main else float('nan')

            # Mean choice RTs (8-image selection)
            mean_choice_rt_all = float(df_main["choice_rt"].mean()) if "choice_rt" in df_main else float('nan')
            mean_choice_rt_correct_match = float(df_main.loc[df_main["correct_match"] == True, "choice_rt"].mean()) if "choice_rt" in df_main else float('nan')
            mean_choice_rt_correct_identity = float(df_main.loc[df_main["correct_identity"] == True, "choice_rt"].mean()) if "choice_rt" in df_main else float('nan')
            mean_choice_rt_incorrect = float(df_main.loc[df_main["correct_identity"] == False, "choice_rt"].mean()) if "choice_rt" in df_main else float('nan')

            # Mean memorability by outcome (target_memorability)
            mean_mem_correct_match = float(df_main.loc[df_main["correct_match"] == True, "target_memorability"].mean()) if "target_memorability" in df_main else float('nan')
            mean_mem_correct_identity = float(df_main.loc[df_main["correct_identity"] == True, "target_memorability"].mean()) if "target_memorability" in df_main else float('nan')
            mean_mem_incorrect = float(df_main.loc[df_main["correct_identity"] == False, "target_memorability"].mean()) if "target_memorability" in df_main else float('nan')

            # mean VVIQ
            # (to_numeric with errors='coerce' turns the None placeholders on
            # unrated trials into NaN, which .mean() then skips)
            if "vividness_rating" in df_main:
                mean_vividness_overall = float(pd.to_numeric(df_main["vividness_rating"], errors='coerce').mean())
            else:
                mean_vividness_overall = float('nan')

            summary_df = pd.DataFrame({
                "trials_completed": [total_trials_main], # how many main trials were completed
                "total_correct_matches": [total_correct_match], # how many times exactly the prompted image was identified in the grid
                "total_semantic_match": [total_semantic_match], # how many times the mirror image was selected in the grid

                "mean_reaction_time": [mean_rt_all], # mean time between prompt and pressing space
                "mean_rt_correct_matches": [mean_rt_correct_match], # mean reaction_time over trials with correct_match = True
                "mean_rt_correct_identity": [mean_rt_correct_identity], # mean reaction_time where correct_identity = True
                "mean_rt_incorrect_matches": [mean_rt_incorrect], # mean reaction_time where correct_identity = False

                "mean_choice_rt": [mean_choice_rt_all], # mean of choice_rt over all main trials
                "mean_choice_rt_correct_matches": [mean_choice_rt_correct_match], # mean choice_rt where correct_match = True
                "mean_choice_rt_correct_identity": [mean_choice_rt_correct_identity], # mean choice_rt where correct_identity = True
                "mean_choice_rt_incorrect_matches": [mean_choice_rt_incorrect], # mean choice_rt where correct_identity = False

                "mean_mem_correct_matches": [mean_mem_correct_match], # mean target_memorability where correct_match = True
                "mean_mem_correct_identity": [mean_mem_correct_identity], # mean target_memorability where correct_identity = True
                "mean_mem_incorrect_matches": [mean_mem_incorrect], # mean target_memorability where correct_identity = False

                "mean_vividness_rating_over_main": [mean_vividness_overall] # mean VVIQ score
            })

            # ---- black_onsets sheet
            # trial number, phase, black onset in ms, absolute time (HH:MM:SS.mmm)
            cols_needed = ["trial", "phase", "black_onset_ms", "black_onset_abs_time"]
            # Some rows (if any unexpected) might miss fields; ensure columns exist:
            for c in cols_needed:
                if c not in df.columns:
                    df[c] = None
            # Sort by phase then trial for readability
            black_df = df.loc[:, cols_needed].copy().sort_values(["phase", "trial"])

            # Write sheets; prefer openpyxl, fall back to xlsxwriter, else CSVs
            try:
                with pd.ExcelWriter(xlsx_filename, engine="openpyxl") as writer:
                    df.to_excel(writer, sheet_name="trials", index=False)
                    summary_df.to_excel(writer, sheet_name="summary", index=False)
                    black_df.to_excel(writer, sheet_name="black_onsets", index=False)
            except Exception:
                try:
                    with pd.ExcelWriter(xlsx_filename, engine="xlsxwriter") as writer:
                        df.to_excel(writer, sheet_name="trials", index=False)
                        summary_df.to_excel(writer, sheet_name="summary", index=False)
                        black_df.to_excel(writer, sheet_name="black_onsets", index=False)
                except Exception:
                    # CSV fallbacks match the same suffix
                    df.to_csv(os.path.join(participant_dir, base_filename + "_trials_fallback.csv"), index=False)
                    summary_df.to_csv(os.path.join(participant_dir, base_filename + "_summary_fallback.csv"), index=False)
                    black_df.to_csv(os.path.join(participant_dir, base_filename + "_black_onsets_fallback.csv"), index=False)
                    print("⚠️ Could not write Excel file (openpyxl/xlsxwriter missing). Saved CSVs instead.")

            # NOTE(review): printed unconditionally, even when the CSV
            # fallback branch above was taken instead of the XLSX write.
            print(f"\n✅ Results saved to:\n  {xlsx_filename}")

# Run the experiment only when executed as a script (not when imported).
if __name__ == "__main__":
    main()
