Skip to content
8 changes: 8 additions & 0 deletions Shoko.Server/Providers/TraktTV/TraktConstants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ public enum TraktSyncType
HistoryRemove = 2
}

public enum TraktAuthTokenValidationResult
{
Valid = 1,
Invalid = 2,
Unknown = 3
}

public static class TraktStatusCodes
{
// http://docs.trakt.apiary.io/#introduction/status-codes
Expand All @@ -34,6 +41,7 @@ public static class TraktStatusCodes
public const int Conflict = 409;

public const int Precondition_Failed = 412;
public const int Denied = 418;
public const int Account_Limit_Exceeded = 420;
public const int Account_Locked = 423;
public const int Unprocessable_Entity = 422;
Expand Down
172 changes: 136 additions & 36 deletions Shoko.Server/Providers/TraktTV/TraktTVHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Shoko.Abstractions.Extensions;
using Shoko.Abstractions.User.Services;
using Shoko.Server.Models.Shoko;
Expand Down Expand Up @@ -36,6 +37,34 @@ public TraktTVHelper(ILogger<TraktTVHelper> logger, ISettingsProvider settingsPr

#region Helpers

private static bool IsInvalidGrantResponse(string response)
{
if (string.IsNullOrWhiteSpace(response))
return false;

try
{
var json = JObject.Parse(response);
return string.Equals(
(string?)json["error"],
"invalid_grant",
StringComparison.OrdinalIgnoreCase);
}
catch
{
return false;
}
}

private static bool IsExpectedDeviceTokenPollingStatus(HttpStatusCode statusCode)
=> (int)statusCode is
TraktStatusCodes.Awaiting_Auth or
TraktStatusCodes.Not_Found or
TraktStatusCodes.Conflict or
TraktStatusCodes.Token_Expired or
TraktStatusCodes.Denied or
TraktStatusCodes.Rate_Limit_Exceeded;

private int SendData(string uri, string json, string verb, Dictionary<string, string> headers, ref string webResponse)
{
var ret = 400;
Expand All @@ -59,29 +88,23 @@ private int SendData(string uri, string json, string verb, Dictionary<string, st
}

// post to trakt
var postStream = request.GetRequestStream();
using var postStream = request.GetRequestStream();
postStream.Write(data, 0, data.Length);

// get the response
var response = (HttpWebResponse)request.GetResponse();
using var response = (HttpWebResponse)request.GetResponse();
using var responseStream = response.GetResponseStream();

var responseStream = response.GetResponseStream();
if (responseStream == null)
{
return ret;
}

var reader = new StreamReader(responseStream);
using var reader = new StreamReader(responseStream);
var strResponse = reader.ReadToEnd();

var statusCode = (int)response.StatusCode;

// cleanup
postStream.Close();
responseStream.Close();
reader.Close();
response.Close();

webResponse = strResponse;
_logger.LogTrace("Trakt SEND Data - Response\nStatus Code: {StatusCode}\nResponse: {Response}", statusCode,
strResponse);
Expand All @@ -90,33 +113,46 @@ private int SendData(string uri, string json, string verb, Dictionary<string, st
}
catch (WebException webEx)
{
if (webEx.Status == WebExceptionStatus.ProtocolError)
if (webEx.Status == WebExceptionStatus.ProtocolError &&
webEx.Response is HttpWebResponse response)
{
if (webEx.Response is HttpWebResponse response)
if (response.ResponseUri.AbsoluteUri != TraktURIs.OAuthDeviceToken && response.StatusCode == HttpStatusCode.BadRequest)
{
{
_logger.LogError(webEx, "Error in SendData: {StatusCode}", (int)response.StatusCode);
ret = (int)response.StatusCode;
}
try
{
var responseStream2 = response.GetResponseStream();
if (responseStream2 == null)
{
return ret;
}
ret = (int)response.StatusCode;

var reader2 = new StreamReader(responseStream2);
webResponse = reader2.ReadToEnd();
_logger.LogError("Error in SendData: {Response}", webResponse);
}
catch
{
// ignore
}
try
{
using var responseStream = response.GetResponseStream();
if (responseStream is not null)
{
using var reader = new StreamReader(responseStream);
webResponse = reader.ReadToEnd();
}
}
catch
{
// ignore response body read failures
}

var isDeviceTokenPolling = response.ResponseUri.AbsoluteUri == TraktURIs.OAuthDeviceToken;
var isInvalidGrant = response.StatusCode == HttpStatusCode.BadRequest &&
IsInvalidGrantResponse(webResponse);

if (isInvalidGrant)
{
_logger.LogWarning(
"Trakt OAuth token request failed with invalid_grant. The token is invalid, expired, or revoked and must be re-authenticated.");
}
else if (!isDeviceTokenPolling || !IsExpectedDeviceTokenPollingStatus(response.StatusCode))
{
_logger.LogError(
webEx,
"Error in SendData: {StatusCode}. Response: {Response}",
ret,
webResponse);
}

return ret;
}

if (webEx.Response != null && webEx.Response.ResponseUri.AbsoluteUri != TraktURIs.OAuthDeviceToken)
{
_logger.LogError(webEx, "{Ex}", webEx.ToString());
Expand Down Expand Up @@ -206,9 +242,54 @@ private Dictionary<string, string> BuildRequestHeaders()

#region Authorization

public TraktAuthTokenValidationResult ValidateAuthToken()
{
var request = (HttpWebRequest)WebRequest.Create(TraktURIs.UserSettings);

_logger.LogTrace("Trakt token validation\nuri: {Uri}", TraktURIs.UserSettings);

request.KeepAlive = true;
request.Method = "GET";
request.ContentLength = 0;
request.Timeout = 120000;
request.ContentType = "application/json";
request.UserAgent = "JMM";
foreach (var header in BuildRequestHeaders())
{
request.Headers.Add(header.Key, header.Value);
}

try
{
using var response = (HttpWebResponse)request.GetResponse();
return (int)response.StatusCode == TraktStatusCodes.Success
? TraktAuthTokenValidationResult.Valid
: TraktAuthTokenValidationResult.Unknown;
}
catch (WebException ex) when (ex.Response is HttpWebResponse response)
{
var statusCode = (int)response.StatusCode;
if (statusCode is TraktStatusCodes.Unauthorized or TraktStatusCodes.Forbidden)
{
_logger.LogWarning("Trakt auth token validation failed with {StatusCode}.", statusCode);
return TraktAuthTokenValidationResult.Invalid;
}

_logger.LogError(ex, "Error validating Trakt auth token: {StatusCode}", statusCode);
return TraktAuthTokenValidationResult.Unknown;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error validating Trakt auth token");
return TraktAuthTokenValidationResult.Unknown;
}
}

public bool RefreshAuthToken()
{
var settings = _settingsProvider.GetSettings();
var shouldSaveSettings = false;

try
{
if (!settings.TraktTv.Enabled ||
Expand All @@ -218,6 +299,8 @@ public bool RefreshAuthToken()
settings.TraktTv.AuthToken = string.Empty;
settings.TraktTv.RefreshToken = string.Empty;
settings.TraktTv.TokenExpirationDate = string.Empty;
settings.TraktTv.Enabled = false;
shouldSaveSettings = true;

return false;
}
Expand All @@ -229,11 +312,11 @@ public bool RefreshAuthToken()
var retData = string.Empty;
TraktTVRateLimiter.Instance.EnsureRate();
var response = SendData(TraktURIs.Oauth, json, "POST", headers, ref retData);

if (response is TraktStatusCodes.Success or TraktStatusCodes.Success_Post)
{
var loginResponse = retData.FromJSON<TraktAuthToken>();

// save the token to the config file to use for subsequent API calls
settings.TraktTv.AuthToken = loginResponse.AccessToken;
settings.TraktTv.RefreshToken = loginResponse.RefreshToken;

Expand All @@ -242,28 +325,45 @@ public bool RefreshAuthToken()
var expireDate = createdAt + validity;

settings.TraktTv.TokenExpirationDate = expireDate.ToString();
shouldSaveSettings = true;

return true;
}

if (IsInvalidGrantResponse(retData))
{
_logger.LogWarning("Trakt refresh token is invalid, expired, or revoked. Disabling Trakt until it is re-authenticated.");
settings.TraktTv.Enabled = false;
}
else
{
_logger.LogWarning("Failed to refresh Trakt auth token. Response code: {ResponseCode}. Response data: {ResponseData}", response, retData);
}

settings.TraktTv.AuthToken = string.Empty;
settings.TraktTv.RefreshToken = string.Empty;
settings.TraktTv.TokenExpirationDate = string.Empty;
shouldSaveSettings = true;

return false;
}
catch (Exception ex)
{
settings.TraktTv.AuthToken = string.Empty;
settings.TraktTv.RefreshToken = string.Empty;
settings.TraktTv.TokenExpirationDate = string.Empty;
shouldSaveSettings = true;

_logger.LogError(ex, "Error in TraktTVHelper.RefreshAuthToken");
return false;
}
finally
{
Utils.SettingsProvider.SaveSettings();
if (shouldSaveSettings)
{
_settingsProvider.SaveSettings();
}
}
return false;
}

#endregion
Expand Down
1 change: 1 addition & 0 deletions Shoko.Server/Providers/TraktTV/TraktURIs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ public static class TraktURIs

public const string OAuthDeviceCode = TraktConstants.BaseAPIURL + @"/oauth/device/code";
public const string OAuthDeviceToken = TraktConstants.BaseAPIURL + @"/oauth/device/token";
public const string UserSettings = TraktConstants.BaseAPIURL + @"/users/settings";

// add to history (mark as watched)
// used for movies, series, episodes
Expand Down
18 changes: 18 additions & 0 deletions Shoko.Server/Scheduling/Jobs/Trakt/CheckTraktTokenJob.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@ public override Task Process()
return Task.CompletedTask;
}

var validationResult = _traktHelper.ValidateAuthToken();
if (validationResult == TraktAuthTokenValidationResult.Invalid)
{
_logger.LogInformation("Trakt auth token is no longer valid. Refreshing token.");
if (_traktHelper.RefreshAuthToken())
{
var newExpirationDate = DateTimeOffset.FromUnixTimeSeconds(long.Parse(settings.TraktTv.TokenExpirationDate)).DateTime;
_logger.LogInformation("Trakt token refreshed successfully. New expiry date: {Date}", newExpirationDate);
}

return Task.CompletedTask;
}

if (validationResult == TraktAuthTokenValidationResult.Unknown)
{
_logger.LogWarning("Unable to validate Trakt auth token. Falling back to stored expiry date.");
}

// Convert the Unix timestamp to DateTime
var expirationDate = DateTimeOffset.FromUnixTimeSeconds(long.Parse(settings.TraktTv.TokenExpirationDate)).DateTime;

Expand Down
26 changes: 24 additions & 2 deletions Shoko.Server/Scheduling/QuartzStartup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,22 @@ public static class QuartzStartup
{
public static async Task ScheduleRecurringJobs(bool replace)
{
var settings = Utils.SettingsProvider.GetSettings();

// this needs to run immediately upon scheduling, so it replaces always. Others will run on other schedules
// Also give it a high priority, since it affects Acquisition Filters
// StartJobNow gives a priority of 10. We'll give it 20 to be even higher priority
await ScheduleRecurringJob<CheckNetworkAvailabilityJob>(
triggerConfig: t => t.WithPriority(20).WithSimpleSchedule(tr => tr.WithIntervalInMinutes(30).RepeatForever()).StartNow(), replace: true, keepSchedule: false);
await ScheduleRecurringJob<CheckTraktTokenJob>(
triggerConfig: t => t.WithPriority(20).WithSimpleSchedule(tr => tr.WithIntervalInMinutes(60).RepeatForever()).StartNow(), replace: true, keepSchedule: false);
if (settings.TraktTv.Enabled)
{
await ScheduleRecurringJob<CheckTraktTokenJob>(
triggerConfig: t => t.WithPriority(20).WithSimpleSchedule(tr => tr.WithIntervalInMinutes(60).RepeatForever()).StartNow(), replace: true, keepSchedule: false);
}
else
{
await RemoveRecurringJob<CheckTraktTokenJob>();
}

// TODO the other schedule-based jobs that are on timers
}
Expand Down Expand Up @@ -87,6 +96,19 @@ await scheduler.ScheduleJob(JobBuilder<T>.Create().UsingJobData(jobConfig).WithG
}
}

private static async Task RemoveRecurringJob<T>() where T : class, IJob
{
var groupName = typeof(T).GetCustomAttribute<JobKeyGroupAttribute>()?.GroupName;
var jobKey = JobKeyBuilder<T>.Create().WithGroup(groupName).Build();
var scheduler = await Utils.ServiceContainer.GetRequiredService<ISchedulerFactory>().GetScheduler();

using var _ = await QuartzExtensions.SchedulerLock.WriterLockAsync();
if (await scheduler.CheckExists(jobKey))
{
await scheduler.DeleteJob(jobKey);
}
}

internal static void AddQuartz(this IServiceCollection services, ISystemService systemService)
{
// this lets us inject the shoko JobFactory explicitly, instead of only IJobFactory
Expand Down
Loading