# Neural Correlates of Memory and Attention Study 2025
# Sustained Attention to Response Task (SART) Test
# Go/No-Go SART: press SPACE for digits 1,2,4,5,6,7,8,9; withhold for digit '3' 
# as described by: Robertson, I. H., Manly, T., Andrade, J., Baddeley, B. T., & Yiend, J. (1997). ‘Oops!’: Performance correlates of everyday attentional failures in traumatic brain injured and normal subjects. Neuropsychologia, 35(6), 747–758.

from psychopy import visual, event, core
import random
import pandas as pd
import datetime
import os
import sys

# ---------- Resolve script directory ----------
# Output is stored relative to this script's folder. Some launchers do not
# define __file__, so fall back to the invoked program path (sys.argv[0]) and,
# if that is missing or empty, to the current working directory.
try:
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    if sys.argv and sys.argv[0]:
        SCRIPT_DIR = os.path.dirname(os.path.abspath(sys.argv[0]))
    else:
        SCRIPT_DIR = os.getcwd()

# ---- Saving the data ----
def save_sart(trials_df, summary_df, base):
    """Write the trial-level and summary tables of one SART run to disk.

    `base` is the full output path without extension, e.g.
    /.../data/PID/PID_SART or /.../data/PID/PID_SART_INTERRUPTED.

    First tries a single .xlsx workbook with the sheets "Trials" and
    "Summary", using the openpyxl engine and then the xlsxwriter engine.
    If neither works, both tables are written as UTF-8 CSV files named
    <base>_Trials.csv and <base>_Summary.csv. Saved paths are printed.
    """
    workbook_path = base + ".xlsx"

    def _write_workbook(engine):
        # One workbook, two sheets; the row index is not written.
        with pd.ExcelWriter(workbook_path, engine=engine) as writer:
            trials_df.to_excel(writer, index=False, sheet_name="Trials")
            summary_df.to_excel(writer, index=False, sheet_name="Summary")

    try:
        try:
            _write_workbook("openpyxl")
        except Exception:
            # openpyxl missing or failing: give xlsxwriter a chance before CSV.
            _write_workbook("xlsxwriter")
        print(f"✅ Saved: {os.path.abspath(workbook_path)}")
    except Exception as e:
        print(f"⚠️ Excel write failed ({e}). Writing CSV fallbacks.")
        trials_csv, summary_csv = base + "_Trials.csv", base + "_Summary.csv"
        trials_df.to_csv(trials_csv, index=False, encoding="utf-8")
        summary_df.to_csv(summary_csv, index=False, encoding="utf-8")
        for written_path in (trials_csv, summary_csv):
            print(f"💾 Saved: {os.path.abspath(written_path)}")

# ----- Prompting Participant ID in terminal ----
# The ID becomes both a folder name (data/<PID>) and a filename prefix
# (<PID>_SART...), so it must be one valid path component. A "/" or "\" would
# create nested folders here and make every save fail at the END of the
# session (the <PID>_SART path would point into a non-existent folder),
# losing the participant's data. Reject such IDs up front instead.
_PID_FORBIDDEN_CHARS = set('<>:"/\\|?*')


def _pid_problem(candidate):
    """Return why `candidate` is unusable as a folder/file name, or None if it is OK."""
    if not candidate:
        return "Participant ID cannot be empty."
    if candidate in (".", ".."):
        return "Participant ID cannot be '.' or '..'."
    bad = sorted({ch for ch in candidate if ch in _PID_FORBIDDEN_CHARS or ord(ch) < 32})
    if bad:
        return ("Participant ID contains characters not allowed in file names: "
                + " ".join(repr(ch) for ch in bad) + ".")
    return None


try:
    while True:
        pid = input("Enter Participant ID: ").strip()
        problem = _pid_problem(pid)
        if problem is None:
            break
        print(f"{problem} Press Ctrl+C to abort.")
except (KeyboardInterrupt, EOFError):
    print("\nExiting.")
    core.quit()

# ---------- Build save directories ----------
# Each participant's files go to <script folder>/data/<PID>/, created on demand.
DATA_DIR = os.path.join(SCRIPT_DIR, "data")
PARTICIPANT_DIR = os.path.join(DATA_DIR, pid)
os.makedirs(PARTICIPANT_DIR, exist_ok=True)

# Session start time (second resolution); copied into every trial and summary row.
timestamp = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"

# ---------- Task parameters ----------
# Values follow: Robertson, I. H., Manly, T., Andrade, J., Baddeley, B. T., & Yiend, J. (1997). ‘Oops!’: Performance correlates of everyday attentional failures in traumatic brain injured and normal subjects. Neuropsychologia, 35(6), 747–758.
font_sizes = [48, 72, 94, 100, 120]        # digit heights, in the window's units ('pix')
digits = [1, 2, 3, 4, 5, 6, 7, 8, 9]       # stimulus set: the digits 1 through 9
total_trials = 225                         # 9 digits x 25 presentations
target_digit = 3                           # the no-go digit: participants withhold on it
digit_duration = 250 / 1000                # digit on screen for 250 ms (seconds)
mask_duration = 900 / 1000                 # then a 900 ms mask (seconds)

# ---------- Window & reusable stimuli ----------
win = visual.Window(units='pix', color='black', fullscr=True)
text = visual.TextStim(win, text='', color='white', height=40, wrapWidth=1200)  # general message line
mask = visual.TextStim(win, text='⊗', color='white', height=120)                # mask symbol shown after each digit
clock = core.Clock()  # trial clock; reset at the start of every trial

# Trial records are kept at module level and created BEFORE the helpers below,
# so an early quit (save_and_quit) can always save whatever has been collected.
trials = []

# ---------- Summary helpers ----------
def mean(lst):
    """Return the arithmetic mean of `lst` ignoring None entries, rounded to 2 dp.

    Returns None when nothing is left to average (empty or all-None input).
    """
    present = [value for value in lst if value is not None]
    if not present:
        return None
    return round(sum(present) / len(present), 2)

# ---- Summary computation: one row reported alongside the trial-level data ----
def compute_summary(trials_list):
    """Aggregate SART trial records into a one-row summary.

    Parameters
    ----------
    trials_list : list of dict
        Trial records in presentation order, as appended by the trial loop.
        Each record provides "go_or_nogo" ("go"/"nogo"), "response" (bool),
        "rt_ms" (None when there was no press), "correct_response",
        "anticipatory" and "valid_rt".

    Returns
    -------
    list of dict
        A single-element list (ready for ``pd.DataFrame``) holding the
        module-level ``pid`` and ``timestamp``, error counts, and four RT
        means (ms, rounded to 2 dp by ``mean``) taken over every qualifying
        no-go trial. An RT mean is None when no no-go trial qualifies.
    """
    # Basic counts
    correct_suppressions = sum(1 for t in trials_list if t["go_or_nogo"] == "nogo" and not t["response"])  # correctly withheld on the '3'
    incorrect_response = sum(1 for t in trials_list if t["go_or_nogo"] == "nogo" and t["response"])  # pressed SPACE on the '3' (false alarm / commission error)
    omissions = sum(1 for t in trials_list if t["go_or_nogo"] == "go" and not t["response"])  # no press on a go digit
    anticipatory_count = sum(1 for t in trials_list if t["anticipatory"])  # trials flagged anticipatory by the trial loop (RT < 100 ms)
    valid_rt_count = sum(1 for t in trials_list if t["valid_rt"])  # trials flagged valid by the trial loop (RT >= 200 ms)

    def four_prev_go_mean(i):
        """Mean RT of the 4 trials before index i, or None.

        All 4 must be correct go trials with a recorded RT; otherwise (or
        when fewer than 4 trials precede i) the window is skipped.
        """
        if i < 4:
            return None
        seg = trials_list[i - 4:i]
        vals = [t["rt_ms"] for t in seg
                if t["go_or_nogo"] == "go"
                and t["correct_response"]
                and t["rt_ms"] is not None]  # excluding trials without response
        return (sum(vals) / 4.0) if len(vals) == 4 else None

    def next_four_presses_mean(start_idx):
        """Mean RT of the first 4 go-trial presses after start_idx, or None if fewer remain."""
        rts = []
        j = start_idx + 1
        n_trials = len(trials_list)
        while j < n_trials and len(rts) < 4:
            t = trials_list[j]
            # Only presses on GO trials count. A press on a later '3' is a
            # false alarm, not a post-no-go go RT, and must not enter this mean.
            if (t.get("go_or_nogo") == "go"
                    and t.get("response")
                    and t.get("rt_ms") is not None):
                rts.append(t["rt_ms"])
            j += 1
        return (sum(rts) / 4.0) if len(rts) == 4 else None

    pre_rt_correct = []
    pre_rt_incorrect = []
    post_means_after_correct_withhold = []
    post_means_after_false_alarm = []

    for i, tr in enumerate(trials_list):
        if tr["go_or_nogo"] != "nogo":
            continue  # only no-go trials anchor the pre/post windows

        # Pre-nogo sets
        m = four_prev_go_mean(i)
        if m is not None:
            if tr["response"]:  # false alarm
                pre_rt_incorrect.append(m)
            else:               # correct withhold
                pre_rt_correct.append(m)

        # Post-nogo sets
        m_post = next_four_presses_mean(i)
        if m_post is not None:
            if tr["response"]:  # false alarm
                post_means_after_false_alarm.append(m_post)
            else:               # correct withhold
                post_means_after_correct_withhold.append(m_post)

    return [{
        "test": "SART_Summary",
        "participant_id": pid,
        "timestamp": timestamp,
        "correct_suppressions": correct_suppressions,  # correctly did not press SPACE when 3 was shown
        # NOTE: despite its name, this key counts presses on the no-go digit '3'
        # (commission errors); the name is kept so existing output columns stay stable.
        "incorrect_go_responses": incorrect_response,
        "omission_errors": omissions,  # no press on a go digit
        "anticipatory_responses": anticipatory_count,  # RT < 100 ms
        "valid_responses": valid_rt_count,  # RT >= 200 ms
        "mean_pre_rt_correct_suppression": mean(pre_rt_correct),  # mean RT of the 4 go trials before a correct withhold
        "mean_pre_rt_incorrect_suppression": mean(pre_rt_incorrect),  # mean RT of the 4 go trials before a false alarm
        "mean_post_rt_after_correct_withhold": mean(post_means_after_correct_withhold),  # mean RT of the next 4 go presses after a correct withhold
        "mean_post_rt_after_false_alarm": mean(post_means_after_false_alarm),  # mean RT of the next 4 go presses after a false alarm
    }]

# ---- Storing data in case experiment is interrupted ----
def save_and_quit(interrupted=False):
    """Save all trials collected so far, close the window and quit PsychoPy.

    Parameters
    ----------
    interrupted : bool
        When True the output is named <PID>_SART_INTERRUPTED instead of
        <PID>_SART, marking a partial session.

    core.quit() is always called (via ``finally``), even if closing the
    window raises.
    """
    file_stem = f"{pid}_SART_INTERRUPTED" if interrupted else f"{pid}_SART"
    save_sart(
        pd.DataFrame(trials),
        pd.DataFrame(compute_summary(trials)),
        os.path.join(PARTICIPANT_DIR, file_stem),
    )
    try:
        win.close()
    finally:
        core.quit()

# ---------- Quit helpers ----------
def prompt_quit():
    """Show a "really quit?" overlay and block until Y or N is pressed.

    Returns True if Y was pressed (either case), False for N. The event
    buffer is cleared after the answer is read.
    """
    # Translucent dark rectangle covering the whole window; purely cosmetic,
    # so any failure building or drawing it is ignored.
    try:
        half_w = win.size[0] / 2
        half_h = win.size[1] / 2
        backdrop = visual.ShapeStim(
            win,
            units='pix',
            vertices=[(-half_w, -half_h), (half_w, -half_h),
                      (half_w, half_h), (-half_w, half_h)],
            closeShape=True,
            fillColor=[-0.5, -0.5, -0.5],
            lineColor=None,
            opacity=0.6,
        )
        backdrop.draw()
    except Exception:
        pass

    question = visual.TextStim(
        win,
        text="Do you really want to quit?\n\nPress Y to confirm, N to continue.",
        color='white', height=32, wrapWidth=1000,
    )
    question.draw()
    win.flip()

    answer = event.waitKeys(keyList=['y', 'n', 'Y', 'N'])[0]
    event.clearEvents()
    return answer.lower() == 'y'

def check_escape_pressed_and_maybe_quit():
    """Non-blocking Esc check used inside drawing loops.

    Polls only for 'escape' (keyList=['escape']). If it was pressed, asks for
    confirmation: on Y, partial data is saved and the run ends via
    save_and_quit(interrupted=True); on N, the event buffer is cleared.

    Returns True when Esc was pressed but the user chose to continue,
    otherwise False.
    """
    if 'escape' not in event.getKeys(keyList=['escape']):
        return False
    if prompt_quit():
        save_and_quit(interrupted=True)
        return False
    event.clearEvents()
    return True

# ---------- Progress bar (ShapeStim-based) ----------
# The two bar stimuli are built once and reused (rebuilt only if the bar
# geometry changes). draw_progress_bar() runs on every frame of every trial,
# and constructing fresh ShapeStims per frame adds avoidable work inside the
# timed draw loops.
_progress_bar_cache = {}


def draw_progress_bar(current_index, total_trials):
    """Draw (without flipping) a horizontal progress bar near the bottom of `win`.

    Parameters
    ----------
    current_index : int
        Number of trials completed so far.
    total_trials : int
        Denominator of the fill fraction. The fraction is clamped to [0, 1];
        a non-positive total is treated as no progress. A fill of at least
        2 px is always drawn so the bar is visible at 0 %.
    """
    # Size & position (pixel units)
    W, H = win.size
    bar_w = int(min(W * 0.7, 1000))
    bar_h = 16
    y = -H / 2 + 30

    # Bar coordinates
    x0, x1 = -bar_w / 2, bar_w / 2
    y0, y1 = y - bar_h / 2, y + bar_h / 2

    geometry = (x0, x1, y0, y1)
    if _progress_bar_cache.get("geometry") != geometry:
        _progress_bar_cache["geometry"] = geometry
        _progress_bar_cache["fill_w"] = None  # force a vertex update below
        # Track (outline)
        _progress_bar_cache["track"] = visual.ShapeStim(
            win,
            vertices=[(x0, y0), (x1, y0), (x1, y1), (x0, y1)],
            closeShape=True,
            units='pix',
            lineColor='white',
            lineWidth=2,
            fillColor=None,
        )
        # Fill; its vertices are replaced whenever the fill width changes
        _progress_bar_cache["fill"] = visual.ShapeStim(
            win,
            vertices=[(x0, y0), (x1, y0), (x1, y1), (x0, y1)],
            closeShape=True,
            units='pix',
            lineColor=None,
            fillColor='white',
        )

    # Fill according to progress
    if total_trials > 0:
        pct = max(0.0, min(1.0, current_index / float(total_trials)))
    else:
        pct = 0.0
    fill_w = max(2, int(bar_w * pct))

    fill = _progress_bar_cache["fill"]
    if _progress_bar_cache["fill_w"] != fill_w:
        # Width only changes once per trial, so vertices are rarely rebuilt.
        _progress_bar_cache["fill_w"] = fill_w
        fx1 = x0 + fill_w
        fill.vertices = [(x0, y0), (fx1, y0), (fx1, y1), (x0, y1)]

    _progress_bar_cache["track"].draw()
    fill.draw()

# ---------- Instructions Screen ----------
title = visual.TextStim(win, text="Sustained Attention to Response Task Experiment",
                        height=44, color='white', bold=True, pos=(0, 120), wrapWidth=1200)
instructions = visual.TextStim(win,
    text="In the following experiment, please press SPACE for each digit, except when the digit '3' appears.\n"
         "Respond as quickly and accurately as possible.\n\n"
         "(Press SPACE to start)",
    height=32, color='white', wrapWidth=1200, pos=(0, -40))

# Show the instructions until SPACE is pressed. Esc offers to quit (confirm ->
# save and exit); if the user declines, the instructions are shown again so
# the task only ever starts on an explicit SPACE press.
while True:
    title.draw()
    instructions.draw()
    win.flip()
    k = event.waitKeys(keyList=['space', 'escape'])[0]
    if k == 'space':
        break
    if prompt_quit():
        save_and_quit(interrupted=True)

# ---------- Pre-trial mask + progress (interruptible) ----------
# One second of the white mask over an empty progress bar before trial 1.
# Esc can be pressed here; the helper's "chose to continue" result is unused.
mask.color = 'white'
t_pre = clock.getTime()
while clock.getTime() - t_pre < 1.0:
    mask.draw()
    draw_progress_bar(0, total_trials)
    win.flip()
    check_escape_pressed_and_maybe_quit()

# ---------- Trial lists ----------
# Balanced, independently shuffled stimulus lists derived from the task
# parameters (currently 9 digits x 25 = 225 and 5 sizes x 45 = 225).
# Fail fast if the parameters cannot be balanced, rather than hitting an
# IndexError partway through the session.
if total_trials % len(digits) or total_trials % len(font_sizes):
    raise ValueError(
        f"total_trials ({total_trials}) must be a multiple of both "
        f"len(digits) ({len(digits)}) and len(font_sizes) ({len(font_sizes)})"
    )
digit_list = [d for d in digits for _ in range(total_trials // len(digits))]
random.shuffle(digit_list)
font_list = font_sizes * (total_trials // len(font_sizes))
random.shuffle(font_list)

# ---------- Trials ----------
def _poll_trial_keys(digit, response, rt, mask_color):
    """Consume SPACE/Esc presses queued since the last poll of the current trial.

    Esc opens the quit prompt (confirm -> save partial data and exit; decline
    -> clear the event buffer). Only the first SPACE of a trial is recorded:
    `response` becomes True, `rt` is its timestamp on `clock` in ms, and
    `mask_color` turns red (press on the no-go digit) or green (go digit).

    Returns the updated (response, rt, mask_color).
    """
    for key_name, ts in event.getKeys(keyList=['space', 'escape'], timeStamped=clock):
        if key_name == 'escape':
            if prompt_quit():
                save_and_quit(interrupted=True)
            else:
                event.clearEvents()
            continue
        if key_name == 'space' and not response:
            response = True
            rt = ts * 1000.0
            mask_color = 'red' if digit == target_digit else 'green'
    return response, rt, mask_color


for trial_num in range(total_trials):
    digit = digit_list[trial_num]
    font_size = font_list[trial_num]
    digit_stim = visual.TextStim(win, text=str(digit), height=font_size, color='white')

    # Discard keys still queued from earlier screens. The feedback flash and
    # the pre-task mask poll only for Esc, so a late SPACE stays buffered;
    # without this it would be read on this trial's first frame with a
    # timestamp from before clock.reset() -> a spurious response with a
    # negative RT (flagged anticipatory, or a false alarm on a '3').
    event.clearEvents('keyboard')

    # Reset state
    clock.reset()
    rt = None
    response = False
    mask_color = 'white'
    post_flash_color = None

    # Digit phase (250 ms)
    t0 = clock.getTime()
    while (clock.getTime() - t0) < digit_duration:
        digit_stim.draw()
        draw_progress_bar(trial_num, total_trials)
        win.flip()
        response, rt, mask_color = _poll_trial_keys(digit, response, rt, mask_color)

    # Mask phase (900 ms); the mask turns green/red as soon as a press is recorded
    t1 = clock.getTime()
    while (clock.getTime() - t1) < mask_duration:
        mask.color = mask_color
        mask.draw()
        draw_progress_bar(trial_num, total_trials)
        win.flip()
        response, rt, mask_color = _poll_trial_keys(digit, response, rt, mask_color)

    # Outcome
    anticipatory = rt is not None and rt < 100  # if a response to a stimulus is faster than 100ms, it is considered anticipatory
    valid_rt = rt is not None and rt >= 200  # if a response to a stimulus follows >=200ms, it is considered to be a true, non-anticipatory response
    # anticipatory/valid response times as described by: Cheyne, J.A., Solman, G.J.F., Carriere, J.S.A., & Smilek, D. (2009). Anatomy of an error: A bidirectional state model of task engagement/disengagement and attention-related errors. Cognition, 111, 98–113.

    within_window = rt is not None and rt <= 1150  # sum of digit and mask duration
    go_or_nogo = 'nogo' if digit == target_digit else 'go'
    correct = (digit != target_digit and response) or (digit == target_digit and not response)

    # No press: green flash for a correct withhold on '3', red for a missed go digit
    if not response:
        post_flash_color = 'green' if digit == target_digit else 'red'

    # Optional post flash (250 ms) — also allow Esc
    if post_flash_color is not None:
        t2 = clock.getTime()
        while (clock.getTime() - t2) < 0.25:
            mask.color = post_flash_color
            mask.draw()
            draw_progress_bar(trial_num + 1, total_trials)
            win.flip()
            check_escape_pressed_and_maybe_quit()

    trials.append({
        "test": "SART",
        "participant_id": pid,
        "timestamp": timestamp,
        "trial": trial_num + 1,
        "digit": digit,
        "go_or_nogo": go_or_nogo,
        "font_size": font_size,
        "response": response,
        "rt_ms": rt,
        "anticipatory": anticipatory,
        "valid_rt": valid_rt,
        "within_time_window": within_window,
        "correct_response": correct
    })

# ---------- Save (completed run) ----------
final_base = os.path.join(PARTICIPANT_DIR, f"{pid}_SART")
trials_df = pd.DataFrame(trials)
summary_df = pd.DataFrame(compute_summary(trials))
save_sart(trials_df, summary_df, final_base)

# ---------- Goodbye ----------
text.text = "Thank you!\nYou may now close this window."
text.draw()
win.flip()
k = event.waitKeys(keyList=['space', 'escape'])

# Esc + confirmation saves the (complete) data a second time under the
# _INTERRUPTED name before exiting, as intended by the original author.
# NOTE(review): this duplicates a finished run under a partial-run name.
# SPACE, or Esc followed by N, falls through to the normal exit below.
if 'escape' in k and prompt_quit():
    save_and_quit(interrupted=True)

win.close()
core.quit()
