class Solution:
    """Count ways to split a subarray into two parts with equal XOR."""

    def countTriplets(self, arr: List[int]) -> int:
        """Return the number of index triplets ``(i, j, k)``, ``i < j <= k``,
        for which ``arr[i] ^ ... ^ arr[j - 1] == arr[j] ^ ... ^ arr[k]``.

        The two halves are equal exactly when the XOR of the whole span
        ``arr[i..k]`` is zero, i.e. when the prefix XOR ending just before
        ``i`` equals the prefix XOR ending at ``k``. Every such span of
        length ``k - i + 1`` admits ``k - i`` choices of ``j``.

        Instead of enumerating all earlier matching prefixes, the method
        keeps, per prefix-XOR value, how many times it occurred and the sum
        of the indices at which it occurred, so each element is handled
        with a single dictionary lookup.

        Args:
            arr: The input integers.

        Returns:
            The number of qualifying triplets; ``0`` for an empty list.
        """
        # seen[x] = (occurrences of prefix XOR x, sum of the end indices p
        # of those prefixes). The empty prefix has XOR 0 and is recorded
        # with end index -1.
        seen = {0: (1, -1)}
        running = 0
        total = 0
        for k, value in enumerate(arr):
            running ^= value
            count, index_sum = seen.get(running, (0, 0))
            # Each earlier prefix ending at p with the same XOR yields the
            # zero-XOR span arr[p + 1..k], contributing k - p - 1 splits.
            # Summed over all such p: (k - 1) * count - index_sum.
            total += (k - 1) * count - index_sum
            seen[running] = (count + 1, index_sum + k)
        return total