Add IUpdateSource; refactor UpdateManager so update package retrieval is no longer hard-coded.

This commit is contained in:
Caelan Sayler
2022-03-07 19:25:07 +00:00
parent 4895a219b9
commit 65fcbc2fc4
19 changed files with 957 additions and 780 deletions

View File

@@ -1,261 +0,0 @@
using System;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;
using Squirrel.SimpleSplat;
namespace Squirrel
{
/// <summary>
/// A simple abstractable file downloader
/// </summary>
public interface IFileDownloader
{
/// <summary>
/// Downloads a remote file to the specified local path
/// </summary>
/// <param name="url">The url which will be downloaded.</param>
/// <param name="targetFile">
/// The local path where the file will be stored
/// If a file exists at this path, it will be overritten.</param>
/// <param name="progress">
/// A delegate for reporting download progress, with expected values from 0-100.
/// </param>
/// <param name="authorization">
/// Text to be sent in the 'Authorization' header of the request.
/// </param>
Task DownloadFile(string url, string targetFile, Action<int> progress, string authorization = null);
/// <summary>
/// Returns a byte array containing the contents of the file at the specified url
/// </summary>
Task<byte[]> DownloadBytes(string url, string authorization = null);
/// <summary>
/// Returns a string containing the contents of the specified url
/// </summary>
Task<string> DownloadString(string url, string authorization = null);
}
/// <inheritdoc cref="IFileDownloader"/>
public class HttpClientFileDownloader : IFileDownloader
{
/// <summary>
/// The User-Agent sent with Squirrel requests
/// </summary>
public static ProductInfoHeaderValue UserAgent => new("Squirrel", AssemblyRuntimeInfo.ExecutingAssemblyName.Version.ToString());
/// <inheritdoc />
public virtual async Task DownloadFile(string url, string targetFile, Action<int> progress, string authorization)
{
using var client = CreateHttpClient(authorization);
try {
using (var fs = File.Open(targetFile, FileMode.Create)) {
await DownloadToStreamInternal(client, url, fs, progress).ConfigureAwait(false);
}
} catch {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
using (var fs = File.Open(targetFile, FileMode.Create)) {
await DownloadToStreamInternal(client, url.ToLower(), fs, progress).ConfigureAwait(false);
}
}
}
/// <inheritdoc />
public virtual async Task<byte[]> DownloadBytes(string url, string authorization)
{
using var client = CreateHttpClient(authorization);
try {
return await client.GetByteArrayAsync(url).ConfigureAwait(false);
} catch {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
return await client.GetByteArrayAsync(url.ToLower()).ConfigureAwait(false);
}
}
/// <inheritdoc />
public virtual async Task<string> DownloadString(string url, string authorization)
{
using var client = CreateHttpClient(authorization);
try {
return await client.GetStringAsync(url).ConfigureAwait(false);
} catch {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
return await client.GetStringAsync(url.ToLower()).ConfigureAwait(false);
}
}
/// <summary>
/// Asynchronously downloads a remote url to the specified destination stream while
/// providing progress updates.
/// </summary>
protected virtual async Task DownloadToStreamInternal(HttpClient client, string requestUri, Stream destination, Action<int> progress = null, CancellationToken cancellationToken = default)
{
// https://stackoverflow.com/a/46497896/184746
// Get the http headers first to examine the content length
using var response = await client.GetAsync(requestUri, HttpCompletionOption.ResponseHeadersRead).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
var contentLength = response.Content.Headers.ContentLength;
using var download = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
// Ignore progress reporting when no progress reporter was
// passed or when the content length is unknown
if (progress == null || !contentLength.HasValue) {
await download.CopyToAsync(destination).ConfigureAwait(false);
return;
}
var buffer = new byte[81920];
long totalBytesRead = 0;
int bytesRead;
int lastProgress = 0;
while ((bytesRead = await download.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false)) != 0) {
await destination.WriteAsync(buffer, 0, bytesRead, cancellationToken).ConfigureAwait(false);
totalBytesRead += bytesRead;
// Convert absolute progress (bytes downloaded) into relative progress (0% - 100%)
// and don't report progress < 3% difference, kind of like a shitty debounce.
var curProgress = (int) ((double) totalBytesRead / contentLength.Value * 100);
if (curProgress - lastProgress >= 3) {
lastProgress = curProgress;
progress(curProgress);
}
}
if (lastProgress != 100)
progress(100);
}
/// <summary>
/// Creates a new <see cref="HttpClient"/> for every request. Override this
/// function to add a custom proxy or other http configuration.
/// </summary>
protected virtual HttpClient CreateHttpClient(string authorization)
{
var handler = new HttpClientHandler() {
AllowAutoRedirect = true,
MaxAutomaticRedirections = 10,
AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate,
};
var client = new HttpClient(handler, true);
client.DefaultRequestHeaders.UserAgent.Add(UserAgent);
if (authorization != null)
client.DefaultRequestHeaders.Add("Authorization", authorization);
return client;
}
}
/// <inheritdoc cref="IFileDownloader"/>
[Obsolete("Use HttpClientFileDownloader")]
public class FileDownloader : IFileDownloader, IEnableLogger
{
/// <inheritdoc />
public virtual async Task DownloadFile(string url, string targetFile, Action<int> progress, string authorization)
{
using (var wc = CreateWebClient()) {
var failedUrl = default(string);
wc.Headers.Add("Authorization", authorization);
var lastSignalled = DateTime.MinValue;
wc.DownloadProgressChanged += (sender, args) => {
var now = DateTime.Now;
if (now - lastSignalled > TimeSpan.FromMilliseconds(500)) {
lastSignalled = now;
progress(args.ProgressPercentage);
}
};
retry:
try {
this.Log().Info("Downloading file: " + (failedUrl ?? url));
await this.WarnIfThrows(
async () => {
await wc.DownloadFileTaskAsync(failedUrl ?? url, targetFile).ConfigureAwait(false);
progress(100);
},
"Failed downloading URL: " + (failedUrl ?? url)).ConfigureAwait(false);
} catch (Exception) {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
if (failedUrl != null) throw;
failedUrl = url.ToLower();
progress(0);
goto retry;
}
}
}
/// <inheritdoc />
public virtual async Task<byte[]> DownloadBytes(string url, string authorization)
{
using (var wc = CreateWebClient()) {
var failedUrl = default(string);
wc.Headers.Add("Authorization", authorization);
retry:
try {
this.Log().Info("Downloading url: " + (failedUrl ?? url));
return await this.WarnIfThrows(() => wc.DownloadDataTaskAsync(failedUrl ?? url),
"Failed to download url: " + (failedUrl ?? url)).ConfigureAwait(false);
} catch (Exception) {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
if (failedUrl != null) throw;
failedUrl = url.ToLower();
goto retry;
}
}
}
/// <inheritdoc />
public virtual async Task<string> DownloadString(string url, string authorization)
{
using (var wc = CreateWebClient()) {
var failedUrl = default(string);
wc.Headers.Add("Authorization", authorization);
retry:
try {
this.Log().Info("Downloading url: " + (failedUrl ?? url));
return await this.WarnIfThrows(() => wc.DownloadStringTaskAsync(failedUrl ?? url),
"Failed to download url: " + (failedUrl ?? url)).ConfigureAwait(false);
} catch (Exception) {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
if (failedUrl != null) throw;
failedUrl = url.ToLower();
goto retry;
}
}
}
/// <summary>
/// Creates and returns a new WebClient for every requst
/// </summary>
protected virtual WebClient CreateWebClient()
{
var ret = new WebClient();
var wp = WebRequest.DefaultWebProxy;
if (wp != null) {
wp.Credentials = CredentialCache.DefaultCredentials;
ret.Proxy = wp;
}
return ret;
}
}
}

View File

@@ -5,22 +5,21 @@ using System.Runtime.Serialization;
using System.Text;
using System.Threading.Tasks;
using Squirrel.Json;
using Squirrel.Sources;
namespace Squirrel
{
/// <summary>
/// An implementation of UpdateManager which supports checking updates and
/// downloading releases directly from GitHub releases
/// downloading releases directly from GitHub releases. This class is just a shorthand
/// for initialising <see cref="UpdateManager"/> with a <see cref="GithubSource"/>
/// as the first argument.
/// </summary>
#if NET5_0_OR_GREATER
[System.Runtime.Versioning.SupportedOSPlatform("windows")]
#endif
public class GithubUpdateManager : UpdateManager
{
private readonly string _repoUrl;
private readonly string _accessToken;
private readonly bool _prerelease;
/// <inheritdoc cref="UpdateManager(string, string, string, IFileDownloader)"/>
/// <param name="repoUrl">
/// The URL of the GitHub repository to download releases from
@@ -56,89 +55,28 @@ namespace Squirrel
string applicationIdOverride = null,
string localAppDataDirectoryOverride = null,
IFileDownloader urlDownloader = null)
: base(null, applicationIdOverride, localAppDataDirectoryOverride, urlDownloader)
: base(new GithubSource(repoUrl, accessToken, prerelease, urlDownloader), applicationIdOverride, localAppDataDirectoryOverride)
{
_repoUrl = repoUrl;
_accessToken = accessToken;
_prerelease = prerelease;
}
}
/// <inheritdoc />
public override async Task<UpdateInfo> CheckForUpdate(bool ignoreDeltaUpdates = false, Action<int> progress = null, UpdaterIntention intention = UpdaterIntention.Update)
public partial class UpdateManager
{
/// <summary>
/// This function is obsolete and will be removed in a future version,
/// see the <see cref="GithubUpdateManager" /> class for a replacement.
/// </summary>
[System.ComponentModel.EditorBrowsable(System.ComponentModel.EditorBrowsableState.Never)]
[Obsolete("Use 'new UpdateManager(new GithubSource(...))' instead")]
public static Task<UpdateManager> GitHubUpdateManager(
string repoUrl,
string applicationName = null,
string rootDirectory = null,
IFileDownloader urlDownloader = null,
bool prerelease = false,
string accessToken = null)
{
await EnsureReleaseUrl().ConfigureAwait(false);
return await base.CheckForUpdate(ignoreDeltaUpdates, progress, intention).ConfigureAwait(false);
}
/// <inheritdoc />
public override async Task DownloadReleases(IEnumerable<ReleaseEntry> releasesToDownload, Action<int> progress = null)
{
await EnsureReleaseUrl().ConfigureAwait(false);
await base.DownloadReleases(releasesToDownload, progress).ConfigureAwait(false);
}
private async Task EnsureReleaseUrl()
{
if (this._updateUrlOrPath == null) {
this._updateUrlOrPath = await GetLatestGithubReleaseUrl().ConfigureAwait(false);
}
}
private async Task<string> GetLatestGithubReleaseUrl()
{
var repoUri = new Uri(_repoUrl);
var releases = await GetGithubReleases(repoUri, _accessToken, _prerelease, _urlDownloader).ConfigureAwait(false);
return releases.First().DownloadUrl;
}
internal static async Task<IEnumerable<GithubRelease>> GetGithubReleases(Uri repoUri, string token, bool prerelease, IFileDownloader downloader)
{
if (repoUri.Segments.Length != 3) {
throw new Exception("Repo URL must be to the root URL of the repo e.g. https://github.com/myuser/myrepo");
}
var releasesApiBuilder = new StringBuilder("repos")
.Append(repoUri.AbsolutePath)
.Append("/releases");
Uri baseAddress;
if (repoUri.Host.EndsWith("github.com", StringComparison.OrdinalIgnoreCase)) {
baseAddress = new Uri("https://api.github.com/");
} else {
// if it's not github.com, it's probably an Enterprise server
// now the problem with Enterprise is that the API doesn't come prefixed
// it comes suffixed so the API path of http://internal.github.server.local
// API location is http://interal.github.server.local/api/v3
baseAddress = new Uri(string.Format("{0}{1}{2}/api/v3/", repoUri.Scheme, Uri.SchemeDelimiter, repoUri.Host));
}
// above ^^ notice the end slashes for the baseAddress, explained here: http://stackoverflow.com/a/23438417/162694
string bearer = null;
if (!string.IsNullOrWhiteSpace(token))
bearer = "Bearer " + token;
var fullPath = new Uri(baseAddress, releasesApiBuilder.ToString());
var response = await downloader.DownloadString(fullPath.ToString(), bearer).ConfigureAwait(false);
var releases = SimpleJson.DeserializeObject<List<GithubRelease>>(response);
return releases.OrderByDescending(d => d.PublishedAt).Where(x => prerelease || !x.Prerelease);
}
[DataContract]
internal class GithubRelease
{
[DataMember(Name = "prerelease")]
public bool Prerelease { get; set; }
[DataMember(Name = "published_at")]
public DateTime PublishedAt { get; set; }
[DataMember(Name = "html_url")]
public string HtmlUrl { get; set; }
public string DownloadUrl => HtmlUrl.Replace("/tag/", "/download/");
return Task.FromResult(new UpdateManager(new GithubSource(repoUrl, accessToken, prerelease, urlDownloader), applicationName, rootDirectory));
}
}
}

View File

@@ -122,9 +122,9 @@ namespace Squirrel
}
}
public static IFileDownloader CreateDefaultDownloader()
public static Sources.IFileDownloader CreateDefaultDownloader()
{
return new HttpClientFileDownloader();
return new Sources.HttpClientFileDownloader();
}
public static async Task CopyToAsync(string from, string to)
@@ -174,6 +174,30 @@ namespace Squirrel
}
}
public static async Task RetryAsync(this Func<Task> block, int retries = 4, int retryDelay = 250)
{
while (true) {
try {
await block().ConfigureAwait(false);
} catch {
if (retries-- == 0) throw;
await Task.Delay(retryDelay).ConfigureAwait(false);
}
}
}
public static async Task<T> RetryAsync<T>(this Func<Task<T>> block, int retries = 4, int retryDelay = 250)
{
while (true) {
try {
return await block().ConfigureAwait(false);
} catch {
if (retries-- == 0) throw;
await Task.Delay(retryDelay).ConfigureAwait(false);
}
}
}
/*
* caesay 09/12/2021 at 12:10 PM
* yeah

View File

@@ -10,6 +10,7 @@ using System.Threading;
using System.Threading.Tasks;
using Microsoft.Win32;
using Squirrel.SimpleSplat;
using Squirrel.Sources;
namespace Squirrel
{

View File

@@ -0,0 +1,229 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.Serialization;
using System.Threading.Tasks;
using Squirrel.Json;
namespace Squirrel.Sources
{
/// <summary> Describes a GitHub release, including attached assets. </summary>
[DataContract]
public class GithubRelease
{
/// <summary> The name of this release. </summary>
[DataMember(Name = "name")]
public string Name { get; set; }
/// <summary> True if this release is a prerelease. </summary>
[DataMember(Name = "prerelease")]
public bool Prerelease { get; set; }
/// <summary> The date which this release was published publically. </summary>
[DataMember(Name = "published_at")]
public DateTime PublishedAt { get; set; }
/// <summary> A list of assets (files) uploaded to this release. </summary>
[DataMember(Name = "assets")]
public GithubReleaseAsset[] Assets { get; set; }
}
/// <summary> Describes a asset (file) uploaded to a GitHub release. </summary>
[DataContract]
public class GithubReleaseAsset
{
/// <summary>
/// The asset URL for this release asset. Requests to this URL will use API
/// quota and return JSON unless the 'Accept' header is "application/octet-stream".
/// </summary>
[DataMember(Name = "url")]
public string Url { get; set; }
/// <summary>
/// The browser URL for this release asset. This does not use API quota,
/// however this URL only works for public repositories. If downloading
/// assets from a private repository, the <see cref="Url"/> property must
/// be used with an appropriate access token.
/// </summary>
[DataMember(Name = "browser_download_url")]
public string BrowserDownloadUrl { get; set; }
/// <summary> The name of this release asset. </summary>
[DataMember(Name = "name")]
public string Name { get; set; }
/// <summary> The mime type of this release asset (as detected by GitHub). </summary>
[DataMember(Name = "content_type")]
public string ContentType { get; set; }
}
/// <summary>
/// Retrieves available releases from a GitHub repository. This class only
/// downloads assets from the very latest GitHub release.
/// </summary>
public class GithubSource : IUpdateSource
{
/// <summary>
/// The URL of the GitHub repository to download releases from
/// (e.g. https://github.com/myuser/myrepo)
/// </summary>
public virtual Uri RepoUri { get; }
/// <summary>
/// If true, the latest pre-release will be downloaded. If false, the latest
/// stable release will be downloaded.
/// </summary>
public virtual bool Prerelease { get; }
/// <summary>
/// The file downloader used to perform HTTP requests.
/// </summary>
public virtual IFileDownloader Downloader { get; }
/// <summary>
/// The GitHub release which this class should download assets from when
/// executing <see cref="DownloadReleaseEntry"/>. This property can be set
/// explicitly, otherwise it will also be set automatically when executing
/// <see cref="GetReleaseFeed(Guid?, ReleaseEntry)"/>.
/// </summary>
public virtual GithubRelease Release { get; set; }
/// <summary>
/// The GitHub access token to use with the request to download releases.
/// If left empty, the GitHub rate limit for unauthenticated requests allows
/// for up to 60 requests per hour, limited by IP address.
/// </summary>
protected virtual string AccessToken { get; }
/// <summary> </summary>
protected virtual string Authorization => String.IsNullOrWhiteSpace(AccessToken) ? null : "Bearer " + AccessToken;
/// <inheritdoc cref="GithubSource" />
/// <param name="repoUrl">
/// The URL of the GitHub repository to download releases from
/// (e.g. https://github.com/myuser/myrepo)
/// </param>
/// <param name="accessToken">
/// The GitHub access token to use with the request to download releases.
/// If left empty, the GitHub rate limit for unauthenticated requests allows
/// for up to 60 requests per hour, limited by IP address.
/// </param>
/// <param name="prerelease">
/// If true, the latest pre-release will be downloaded. If false, the latest
/// stable release will be downloaded.
/// </param>
/// <param name="downloader">
/// The file downloader used to perform HTTP requests.
/// </param>
public GithubSource(string repoUrl, string accessToken, bool prerelease, IFileDownloader downloader = null)
{
RepoUri = new Uri(repoUrl);
AccessToken = accessToken;
Prerelease = prerelease;
Downloader = downloader ?? Utility.CreateDefaultDownloader();
}
/// <inheritdoc />
public virtual async Task<ReleaseEntry[]> GetReleaseFeed(Guid? stagingId = null, ReleaseEntry latestLocalRelease = null)
{
var releases = await GetReleases(Prerelease).ConfigureAwait(false);
if (releases == null || releases.Count() == 0)
throw new Exception($"No GitHub releases found at '{RepoUri}'.");
// CS: we 'cache' the release here, so subsequent calls to DownloadReleaseEntry
// will download assets from the same release in which we returned ReleaseEntry's
// from. A better architecture would be to return an array of "GithubReleaseEntry"
// containing a reference to the GithubReleaseAsset instead.
Release = releases.First();
// this might be a browser url or an api url (depending on whether we have a AccessToken or not)
// https://docs.github.com/en/rest/reference/releases#get-a-release-asset
var assetUrl = GetAssetUrlFromName(Release, "RELEASES");
var releaseBytes = await Downloader.DownloadBytes(assetUrl, Authorization, "application/octet-stream").ConfigureAwait(false);
var txt = Utility.RemoveByteOrderMarkerIfPresent(releaseBytes);
return ReleaseEntry.ParseReleaseFileAndApplyStaging(txt, stagingId).ToArray();
}
/// <inheritdoc />
public virtual Task DownloadReleaseEntry(ReleaseEntry releaseEntry, string localFile, Action<int> progress)
{
if (Release == null) {
throw new InvalidOperationException("No GitHub Release specified. Call GetReleaseFeed or set " +
"GithubSource.Release before calling this function.");
}
// this might be a browser url or an api url (depending on whether we have a AccessToken or not)
// https://docs.github.com/en/rest/reference/releases#get-a-release-asset
var assetUrl = GetAssetUrlFromName(Release, releaseEntry.Filename);
return Downloader.DownloadFile(assetUrl, localFile, progress, Authorization, "application/octet-stream");
}
/// <summary>
/// Retrieves a list of <see cref="GithubRelease"/> from the current repository.
/// </summary>
public virtual async Task<GithubRelease[]> GetReleases(bool includePrereleases, int perPage = 30, int page = 1)
{
// https://docs.github.com/en/rest/reference/releases
var releasesPath = $"repos{RepoUri.AbsolutePath}/releases?per_page={perPage}&page={page}";
var baseUri = GetApiBaseUrl(RepoUri);
var getReleasesUri = new Uri(baseUri, releasesPath);
var response = await Downloader.DownloadString(getReleasesUri.ToString(), Authorization, "application/vnd.github.v3+json").ConfigureAwait(false);
var releases = SimpleJson.DeserializeObject<List<GithubRelease>>(response);
return releases.OrderByDescending(d => d.PublishedAt).Where(x => includePrereleases || !x.Prerelease).ToArray();
}
/// <summary>
/// Given a <see cref="GithubRelease"/> and an asset filename (eg. 'RELEASES') this
/// function will return either <see cref="GithubReleaseAsset.BrowserDownloadUrl"/> or
/// <see cref="GithubReleaseAsset.Url"/>, depending whether an access token is available
/// or not. Throws if the specified release has no matching assets.
/// </summary>
protected virtual string GetAssetUrlFromName(GithubRelease release, string assetName)
{
if (release.Assets == null || release.Assets.Count() == 0) {
throw new ArgumentException($"No assets found in Github Release '{release.Name}'.");
}
IEnumerable<GithubReleaseAsset> allReleasesFiles = release.Assets.Where(a => a.Name.Equals(assetName, StringComparison.InvariantCultureIgnoreCase));
if (allReleasesFiles == null || allReleasesFiles.Count() == 0) {
throw new ArgumentException($"Could not find asset called '{assetName}' in Github Release '{release.Name}'.");
}
var asset = allReleasesFiles.First();
if (String.IsNullOrWhiteSpace(AccessToken)) {
// if no AccessToken provided, we use the BrowserDownloadUrl which does not
// count towards the "unauthenticated api request" limit of 60 per hour per IP.
return asset.BrowserDownloadUrl;
} else {
// otherwise, we use the regular asset url, which will allow us to retrieve
// assets from private repositories
// https://docs.github.com/en/rest/reference/releases#get-a-release-asset
return asset.Url;
}
}
/// <summary>
/// Given a repository URL (e.g. https://github.com/myuser/myrepo) this function
/// returns the API base for performing requests. (eg. "https://api.github.com/"
/// or http://internal.github.server.local/api/v3)
/// </summary>
/// <param name="repoUrl"></param>
/// <returns></returns>
protected virtual Uri GetApiBaseUrl(Uri repoUrl)
{
Uri baseAddress;
if (repoUrl.Host.EndsWith("github.com", StringComparison.OrdinalIgnoreCase)) {
baseAddress = new Uri("https://api.github.com/");
} else {
// if it's not github.com, it's probably an Enterprise server
// now the problem with Enterprise is that the API doesn't come prefixed
// it comes suffixed so the API path of http://internal.github.server.local
// API location is http://internal.github.server.local/api/v3
baseAddress = new Uri(string.Format("{0}{1}{2}/api/v3/", repoUrl.Scheme, Uri.SchemeDelimiter, repoUrl.Host));
}
// above ^^ notice the end slashes for the baseAddress, explained here: http://stackoverflow.com/a/23438417/162694
return baseAddress;
}
}
}

View File

@@ -0,0 +1,129 @@
using System;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;
namespace Squirrel.Sources
{
/// <inheritdoc cref="IFileDownloader"/>
public class HttpClientFileDownloader : IFileDownloader
{
/// <summary>
/// The User-Agent sent with Squirrel requests
/// </summary>
public static ProductInfoHeaderValue UserAgent => new("Squirrel", AssemblyRuntimeInfo.ExecutingAssemblyName.Version.ToString());
/// <inheritdoc />
public virtual async Task DownloadFile(string url, string targetFile, Action<int> progress, string authorization, string accept)
{
using var client = CreateHttpClient(authorization, accept);
try {
using (var fs = File.Open(targetFile, FileMode.Create)) {
await DownloadToStreamInternal(client, url, fs, progress).ConfigureAwait(false);
}
} catch {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
using (var fs = File.Open(targetFile, FileMode.Create)) {
await DownloadToStreamInternal(client, url.ToLower(), fs, progress).ConfigureAwait(false);
}
}
}
/// <inheritdoc />
public virtual async Task<byte[]> DownloadBytes(string url, string authorization, string accept)
{
using var client = CreateHttpClient(authorization, accept);
try {
return await client.GetByteArrayAsync(url).ConfigureAwait(false);
} catch {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
return await client.GetByteArrayAsync(url.ToLower()).ConfigureAwait(false);
}
}
/// <inheritdoc />
public virtual async Task<string> DownloadString(string url, string authorization, string accept)
{
using var client = CreateHttpClient(authorization, accept);
try {
return await client.GetStringAsync(url).ConfigureAwait(false);
} catch {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
return await client.GetStringAsync(url.ToLower()).ConfigureAwait(false);
}
}
/// <summary>
/// Asynchronously downloads a remote url to the specified destination stream while
/// providing progress updates.
/// </summary>
protected virtual async Task DownloadToStreamInternal(HttpClient client, string requestUri, Stream destination, Action<int> progress = null, CancellationToken cancellationToken = default)
{
// https://stackoverflow.com/a/46497896/184746
// Get the http headers first to examine the content length
using var response = await client.GetAsync(requestUri, HttpCompletionOption.ResponseHeadersRead).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
var contentLength = response.Content.Headers.ContentLength;
using var download = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
// Ignore progress reporting when no progress reporter was
// passed or when the content length is unknown
if (progress == null || !contentLength.HasValue) {
await download.CopyToAsync(destination).ConfigureAwait(false);
return;
}
var buffer = new byte[81920];
long totalBytesRead = 0;
int bytesRead;
int lastProgress = 0;
while ((bytesRead = await download.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false)) != 0) {
await destination.WriteAsync(buffer, 0, bytesRead, cancellationToken).ConfigureAwait(false);
totalBytesRead += bytesRead;
// Convert absolute progress (bytes downloaded) into relative progress (0% - 100%)
// and don't report progress < 3% difference, kind of like a shitty debounce.
var curProgress = (int) ((double) totalBytesRead / contentLength.Value * 100);
if (curProgress - lastProgress >= 3) {
lastProgress = curProgress;
progress(curProgress);
}
}
if (lastProgress < 100)
progress(100);
}
/// <summary>
/// Creates a new <see cref="HttpClient"/> for every request. Override this
/// function to add a custom proxy or other http configuration.
/// </summary>
protected virtual HttpClient CreateHttpClient(string authorization, string accept)
{
var handler = new HttpClientHandler() {
AllowAutoRedirect = true,
MaxAutomaticRedirections = 10,
AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate,
};
var client = new HttpClient(handler, true);
client.DefaultRequestHeaders.UserAgent.Add(UserAgent);
if (authorization != null)
client.DefaultRequestHeaders.Add("Authorization", authorization);
if (accept != null)
client.DefaultRequestHeaders.Add("Accept", accept);
return client;
}
}
}

View File

@@ -0,0 +1,40 @@
using System;
using System.Threading.Tasks;
using Squirrel.SimpleSplat;
namespace Squirrel.Sources
{
/// <summary>
/// A simple abstractable file downloader
/// </summary>
public interface IFileDownloader : IEnableLogger
{
/// <summary>
/// Downloads a remote file to the specified local path
/// </summary>
/// <param name="url">The url which will be downloaded.</param>
/// <param name="targetFile">
/// The local path where the file will be stored
/// If a file exists at this path, it will be overritten.</param>
/// <param name="progress">
/// A delegate for reporting download progress, with expected values from 0-100.
/// </param>
/// <param name="authorization">
/// Text to be sent in the 'Authorization' header of the request.
/// </param>
/// <param name="accept">
/// Text to be sent in the 'Accept' header of the request.
/// </param>
Task DownloadFile(string url, string targetFile, Action<int> progress, string authorization = null, string accept = null);
/// <summary>
/// Returns a byte array containing the contents of the file at the specified url
/// </summary>
Task<byte[]> DownloadBytes(string url, string authorization = null, string accept = null);
/// <summary>
/// Returns a string containing the contents of the specified url
/// </summary>
Task<string> DownloadString(string url, string authorization = null, string accept = null);
}
}

View File

@@ -0,0 +1,39 @@
using System;
using System.Threading.Tasks;
using Squirrel.SimpleSplat;
namespace Squirrel.Sources
{
/// <summary>
/// Abstraction for finding and downloading updates from a package source / repository.
/// An implementation may copy a file from a local repository, download from a web address,
/// or even use third party services and parse proprietary data to produce a package feed.
/// </summary>
public interface IUpdateSource : IEnableLogger
{
/// <summary>
/// Retrieve the list of available remote releases from the package source. These releases
/// can subsequently be downloaded with <see cref="DownloadReleaseEntry(ReleaseEntry, string, Action{int})"/>.
/// </summary>
/// <param name="stagingId">A persistent user-id, used for calculating whether a specific
/// release should be available to this user or not. (eg, for the purposes of rolling out
/// an update to only a small portion of users at a time).</param>
/// <param name="latestLocalRelease">The latest / current local release. If specified,
/// metadata from this package may be provided to the remote server (such as package id,
/// or cpu architecture) to ensure that the correct package is downloaded for this user.
/// </param>
/// <returns>An array of <see cref="ReleaseEntry"/> objects that are available for download
/// and are applicable to this user.</returns>
Task<ReleaseEntry[]> GetReleaseFeed(Guid? stagingId = null, ReleaseEntry latestLocalRelease = null);
/// <summary>
/// Download the specified <see cref="ReleaseEntry"/> to the provided local file path.
/// </summary>
/// <param name="releaseEntry">The release to download.</param>
/// <param name="localFile">The path on the local disk to store the file. If this file exists,
/// it will be overwritten.</param>
/// <param name="progress">This delegate will be executed with values from 0-100 as the
/// download is being processed.</param>
Task DownloadReleaseEntry(ReleaseEntry releaseEntry, string localFile, Action<int> progress);
}
}

View File

@@ -0,0 +1,63 @@
using System;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Squirrel.SimpleSplat;
namespace Squirrel.Sources
{
/// <summary>
/// Retrieves available updates from a local or network-attached disk. The directory
/// must contain one or more valid packages, as well as a 'RELEASES' index file.
/// </summary>
public class SimpleFileSource : IUpdateSource
{
/// <summary> The local directory containing packages to update to. </summary>
public virtual DirectoryInfo BaseDirectory { get; }
/// <inheritdoc cref="SimpleFileSource" />
/// <param name="baseDirectory">The directory where to search for packages.</param>
public SimpleFileSource(DirectoryInfo baseDirectory)
{
BaseDirectory = baseDirectory;
}
/// <inheritdoc />
public virtual Task<ReleaseEntry[]> GetReleaseFeed(Guid? stagingId = null, ReleaseEntry latestLocalRelease = null)
{
if (!BaseDirectory.Exists)
throw new Exception($"The local update directory '{BaseDirectory.FullName}' does not exist.");
var releasesPath = Path.Combine(BaseDirectory.FullName, "RELEASES");
this.Log().Info($"Reading RELEASES from '{releasesPath}'");
var fi = new FileInfo(releasesPath);
if (fi.Exists) {
var txt = File.ReadAllText(fi.FullName, encoding: Encoding.UTF8);
return Task.FromResult(ReleaseEntry.ParseReleaseFileAndApplyStaging(txt, stagingId).ToArray());
} else {
var packages = BaseDirectory.EnumerateFiles("*.nupkg");
if (packages.Any()) {
this.Log().Warn($"The file '{releasesPath}' does not exist but directory contains packages. " +
$"This is not valid but attempting to proceed anyway by writing new file.");
return Task.FromResult(ReleaseEntry.BuildReleasesFile(BaseDirectory.FullName).ToArray());
} else {
throw new Exception($"The file '{releasesPath}' does not exist. Cannot update from invalid source.");
}
}
}
/// <inheritdoc />
public virtual Task DownloadReleaseEntry(ReleaseEntry releaseEntry, string localFile, Action<int> progress)
{
var releasePath = Path.Combine(BaseDirectory.FullName, releaseEntry.Filename);
if (!File.Exists(releasePath))
throw new Exception($"The file '{releasePath}' does not exist. The packages directory is invalid.");
File.Copy(releasePath, localFile, true);
progress?.Invoke(100);
return Task.CompletedTask;
}
}
}

View File

@@ -0,0 +1,73 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Squirrel.SimpleSplat;
namespace Squirrel.Sources
{
/// <summary>
/// Retrieves updates from a static file host or other web server.
/// Will perform a request for '{baseUri}/RELEASES' to locate the available packages,
/// and provides query parameters to specify the name of the requested package.
/// </summary>
public class SimpleWebSource : IUpdateSource
{
/// <summary> The URL of the server hosting packages to update to. </summary>
public virtual Uri BaseUri { get; }
/// <summary> The <see cref="IFileDownloader"/> to be used for performing http requests. </summary>
public virtual IFileDownloader Downloader { get; }
/// <inheritdoc cref="SimpleWebSource" />
public SimpleWebSource(string baseUrl, IFileDownloader downloader = null)
: this(new Uri(baseUrl), downloader)
{ }
/// <inheritdoc cref="SimpleWebSource" />
public SimpleWebSource(Uri baseUri, IFileDownloader downloader = null)
{
BaseUri = baseUri;
Downloader = downloader ?? Utility.CreateDefaultDownloader();
}
/// <inheritdoc />
public virtual async Task<ReleaseEntry[]> GetReleaseFeed(Guid? stagingId = null, ReleaseEntry latestLocalRelease = null)
{
var uri = Utility.AppendPathToUri(BaseUri, "RELEASES");
var args = new Dictionary<string, string> {
{ "arch", AssemblyRuntimeInfo.Architecture.ToString().ToLower() }
};
if (latestLocalRelease != null) {
args.Add("id", latestLocalRelease.PackageName);
args.Add("localVersion", latestLocalRelease.Version.ToString());
}
var uriAndQuery = Utility.AddQueryParamsToUri(uri, args);
this.Log().Info($"Downloading RELEASES from '{uriAndQuery}'.");
var bytes = await Downloader.DownloadBytes(uriAndQuery.ToString()).ConfigureAwait(false);
var txt = Utility.RemoveByteOrderMarkerIfPresent(bytes);
return ReleaseEntry.ParseReleaseFileAndApplyStaging(txt, stagingId).ToArray();
}
/// <inheritdoc />
public virtual Task DownloadReleaseEntry(ReleaseEntry releaseEntry, string localFile, Action<int> progress)
{
var baseUri = Utility.EnsureTrailingSlash(BaseUri);
var uri = Utility.AppendPathToUri(new Uri(releaseEntry.BaseUrl), releaseEntry.Filename).ToString();
if (!String.IsNullOrEmpty(releaseEntry.Query)) {
uri += releaseEntry.Query;
}
var source = new Uri(baseUri, uri).AbsolutePath;
this.Log().Info($"Downloading '{releaseEntry.Filename}' from '{source}'.");
return Downloader.DownloadFile(source, localFile, progress);
}
}
}

View File

@@ -0,0 +1,117 @@
using System;
using System.Net;
using System.Threading.Tasks;
using Squirrel.SimpleSplat;
namespace Squirrel.Sources
{
/// This class is obsolete. Use <see cref="HttpClientFileDownloader"/> instead.
[Obsolete("Use HttpClientFileDownloader")]
public class WebClientFileDownloader : IFileDownloader
{
/// <inheritdoc />
public virtual async Task DownloadFile(string url, string targetFile, Action<int> progress, string authorization, string accept)
{
using (var wc = CreateWebClient(authorization, accept)) {
var failedUrl = default(string);
var lastSignalled = DateTime.MinValue;
wc.DownloadProgressChanged += (sender, args) => {
var now = DateTime.Now;
if (now - lastSignalled > TimeSpan.FromMilliseconds(500)) {
lastSignalled = now;
progress(args.ProgressPercentage);
}
};
retry:
try {
this.Log().Info("Downloading file: " + (failedUrl ?? url));
await this.WarnIfThrows(
async () => {
await wc.DownloadFileTaskAsync(failedUrl ?? url, targetFile).ConfigureAwait(false);
progress(100);
},
"Failed downloading URL: " + (failedUrl ?? url)).ConfigureAwait(false);
} catch (Exception) {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
if (failedUrl != null) throw;
failedUrl = url.ToLower();
progress(0);
goto retry;
}
}
}
/// <inheritdoc />
public virtual async Task<byte[]> DownloadBytes(string url, string authorization, string accept)
{
using (var wc = CreateWebClient(authorization, accept)) {
var failedUrl = default(string);
retry:
try {
this.Log().Info("Downloading url: " + (failedUrl ?? url));
return await this.WarnIfThrows(() => wc.DownloadDataTaskAsync(failedUrl ?? url),
"Failed to download url: " + (failedUrl ?? url)).ConfigureAwait(false);
} catch (Exception) {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
if (failedUrl != null) throw;
failedUrl = url.ToLower();
goto retry;
}
}
}
/// <inheritdoc />
public virtual async Task<string> DownloadString(string url, string authorization, string accept)
{
using (var wc = CreateWebClient(authorization, accept)) {
var failedUrl = default(string);
retry:
try {
this.Log().Info("Downloading url: " + (failedUrl ?? url));
return await this.WarnIfThrows(() => wc.DownloadStringTaskAsync(failedUrl ?? url),
"Failed to download url: " + (failedUrl ?? url)).ConfigureAwait(false);
} catch (Exception) {
// NB: Some super brain-dead services are case-sensitive yet
// corrupt case on upload. I can't even.
if (failedUrl != null) throw;
failedUrl = url.ToLower();
goto retry;
}
}
}
/// <summary>
/// Creates and returns a new WebClient for every requst
/// </summary>
protected virtual WebClient CreateWebClient(string authorization, string accept)
{
var ret = new WebClient();
var wp = WebRequest.DefaultWebProxy;
if (wp != null) {
wp.Credentials = CredentialCache.DefaultCredentials;
ret.Proxy = wp;
}
if (authorization != null)
ret.Headers.Add("Authorization", authorization);
if (accept != null)
ret.Headers.Add("Accept", accept);
return ret;
}
}
}

View File

@@ -121,7 +121,7 @@ namespace Squirrel
// functions which acquire a lock are exposed to the consumer.
// also this means the "urlOrPath" param will never be used,
// so we can pass null safely.
var um = new UpdateManager(null);
var um = new UpdateManager();
um.Dispose();
// in the fastExitLookup arguments, we run the squirrel hook and then exit the process

View File

@@ -11,211 +11,134 @@ namespace Squirrel
{
public partial class UpdateManager
{
internal class CheckForUpdateImpl : IEnableLogger
/// <inheritdoc />
public virtual async Task<UpdateInfo> CheckForUpdate(
bool ignoreDeltaUpdates = false,
Action<int> progress = null,
UpdaterIntention intention = UpdaterIntention.Update)
{
readonly string rootAppDirectory;
// lock will be held until this class is disposed
await acquireUpdateLock().ConfigureAwait(false);
public CheckForUpdateImpl(string rootAppDirectory)
{
this.rootAppDirectory = rootAppDirectory;
progress = progress ?? (_ => { });
var localReleases = Enumerable.Empty<ReleaseEntry>();
var stagingId = intention == UpdaterIntention.Install ? null : getOrCreateStagedUserId();
bool shouldInitialize = intention == UpdaterIntention.Install;
var localReleaseFile = Utility.LocalReleaseFileForAppDir(AppDirectory);
if (intention != UpdaterIntention.Install) {
try {
localReleases = Utility.LoadLocalReleases(localReleaseFile);
} catch (Exception ex) {
// Something has gone pear-shaped, let's start from scratch
this.Log().WarnException("Failed to load local releases, starting from scratch", ex);
shouldInitialize = true;
}
}
public async Task<UpdateInfo> CheckForUpdate(
UpdaterIntention intention,
string localReleaseFile,
string updateUrlOrPath,
bool ignoreDeltaUpdates = false,
Action<int> progress = null,
IFileDownloader urlDownloader = null)
{
progress = progress ?? (_ => { });
if (shouldInitialize) initializeClientAppDirectory();
var localReleases = Enumerable.Empty<ReleaseEntry>();
var stagingId = intention == UpdaterIntention.Install ? null : getOrCreateStagedUserId();
var latestLocalRelease = localReleases != null && localReleases.Count() > 0
? localReleases.MaxBy(v => v.Version).First() :
null;
bool shouldInitialize = intention == UpdaterIntention.Install;
progress(33);
if (intention != UpdaterIntention.Install) {
try {
localReleases = Utility.LoadLocalReleases(localReleaseFile);
} catch (Exception ex) {
// Something has gone pear-shaped, let's start from scratch
this.Log().WarnException("Failed to load local releases, starting from scratch", ex);
shouldInitialize = true;
}
}
var remoteReleases = await Utility.RetryAsync(() => _updateSource.GetReleaseFeed(stagingId, latestLocalRelease)).ConfigureAwait(false);
if (shouldInitialize) initializeClientAppDirectory();
progress(66);
string releaseFile;
var updateInfo = determineUpdateInfo(intention, localReleases, remoteReleases, ignoreDeltaUpdates);
var latestLocalRelease = localReleases.Count() > 0 ?
localReleases.MaxBy(x => x.Version).First() :
default(ReleaseEntry);
progress(100);
return updateInfo;
}
// Fetch the remote RELEASES file, whether it's a local dir or an
// HTTP URL
if (Utility.IsHttpUrl(updateUrlOrPath)) {
if (updateUrlOrPath.EndsWith("/")) {
updateUrlOrPath = updateUrlOrPath.Substring(0, updateUrlOrPath.Length - 1);
}
void initializeClientAppDirectory()
{
// On bootstrap, we won't have any of our directories, create them
var pkgDir = PackagesDirectory;
if (Directory.Exists(pkgDir)) {
Utility.DeleteFileOrDirectoryHardOrGiveUp(pkgDir);
}
this.Log().Info("Downloading RELEASES file from {0}", updateUrlOrPath);
Directory.CreateDirectory(pkgDir);
}
int retries = 3;
UpdateInfo determineUpdateInfo(UpdaterIntention intention, IEnumerable<ReleaseEntry> localReleases, IEnumerable<ReleaseEntry> remoteReleases, bool ignoreDeltaUpdates)
{
var packageDirectory = PackagesDirectory;
localReleases = localReleases ?? Enumerable.Empty<ReleaseEntry>();
retry:
if (remoteReleases == null) {
this.Log().Warn("Release information couldn't be determined due to remote corrupt RELEASES file");
throw new Exception("Corrupt remote RELEASES file");
}
try {
var uri = Utility.AppendPathToUri(new Uri(updateUrlOrPath), "RELEASES");
if (!remoteReleases.Any()) {
throw new Exception("Remote release File is empty or corrupted");
}
if (latestLocalRelease != null) {
uri = Utility.AddQueryParamsToUri(uri, new Dictionary<string, string> {
{ "id", latestLocalRelease.PackageName },
{ "localVersion", latestLocalRelease.Version.ToString() },
{ "arch", Environment.Is64BitOperatingSystem ? "amd64" : "x86" }
});
}
var latestFullRelease = Utility.FindCurrentVersion(remoteReleases);
var currentRelease = Utility.FindCurrentVersion(localReleases);
releaseFile = await urlDownloader.DownloadString(uri.ToString()).ConfigureAwait(false);
} catch (WebException ex) {
this.Log().InfoException("Download resulted in WebException (returning blank release list)", ex);
if (latestFullRelease == currentRelease) {
this.Log().Info("No updates, remote and local are the same");
if (retries <= 0) throw;
retries--;
goto retry;
}
var info = UpdateInfo.Create(currentRelease, new[] { latestFullRelease }, packageDirectory);
return info;
}
progress(33);
if (ignoreDeltaUpdates) {
remoteReleases = remoteReleases.Where(x => !x.IsDelta);
}
if (!localReleases.Any()) {
if (intention == UpdaterIntention.Install) {
this.Log().Info("First run, starting from scratch");
} else {
this.Log().Info("Reading RELEASES file from {0}", updateUrlOrPath);
if (!Directory.Exists(updateUrlOrPath)) {
var message = String.Format(
"The directory {0} does not exist, something is probably broken with your application",
updateUrlOrPath);
throw new Exception(message);
}
var fi = new FileInfo(Path.Combine(updateUrlOrPath, "RELEASES"));
if (!fi.Exists) {
var message = String.Format(
"The file {0} does not exist, something is probably broken with your application",
fi.FullName);
this.Log().Warn(message);
var packages = (new DirectoryInfo(updateUrlOrPath)).GetFiles("*.nupkg");
if (packages.Length == 0) {
throw new Exception(message);
}
// NB: Create a new RELEASES file since we've got a directory of packages
ReleaseEntry.WriteReleaseFile(
packages.Select(x => ReleaseEntry.GenerateFromFile(x.FullName)), fi.FullName);
}
releaseFile = File.ReadAllText(fi.FullName, Encoding.UTF8);
progress(33);
this.Log().Warn("No local releases found, starting from scratch");
}
var ret = default(UpdateInfo);
var remoteReleases = ReleaseEntry.ParseReleaseFileAndApplyStaging(releaseFile, stagingId);
progress(66);
return UpdateInfo.Create(null, new[] { latestFullRelease }, packageDirectory);
}
if (!remoteReleases.Any()) {
throw new Exception("Remote release File is empty or corrupted");
if (localReleases.Max(x => x.Version) > remoteReleases.Max(x => x.Version)) {
this.Log().Warn("hwhat, local version is greater than remote version");
return UpdateInfo.Create(Utility.FindCurrentVersion(localReleases), new[] { latestFullRelease }, packageDirectory);
}
return UpdateInfo.Create(currentRelease, remoteReleases, packageDirectory);
}
internal Guid? getOrCreateStagedUserId()
{
var stagedUserIdFile = Path.Combine(PackagesDirectory, ".betaId");
var ret = default(Guid);
try {
if (!Guid.TryParse(File.ReadAllText(stagedUserIdFile, Encoding.UTF8), out ret)) {
throw new Exception("File was read but contents were invalid");
}
ret = determineUpdateInfo(intention, localReleases, remoteReleases, ignoreDeltaUpdates);
progress(100);
this.Log().Info("Using existing staging user ID: {0}", ret.ToString());
return ret;
} catch (Exception ex) {
this.Log().DebugException("Couldn't read staging user ID, creating a blank one", ex);
}
void initializeClientAppDirectory()
{
// On bootstrap, we won't have any of our directories, create them
var pkgDir = Path.Combine(rootAppDirectory, "packages");
if (Directory.Exists(pkgDir)) {
Utility.DeleteFileOrDirectoryHardOrGiveUp(pkgDir);
}
var prng = new Random();
var buf = new byte[4096];
prng.NextBytes(buf);
Directory.CreateDirectory(pkgDir);
}
UpdateInfo determineUpdateInfo(UpdaterIntention intention, IEnumerable<ReleaseEntry> localReleases, IEnumerable<ReleaseEntry> remoteReleases, bool ignoreDeltaUpdates)
{
var packageDirectory = Utility.PackageDirectoryForAppDir(rootAppDirectory);
localReleases = localReleases ?? Enumerable.Empty<ReleaseEntry>();
if (remoteReleases == null) {
this.Log().Warn("Release information couldn't be determined due to remote corrupt RELEASES file");
throw new Exception("Corrupt remote RELEASES file");
}
var latestFullRelease = Utility.FindCurrentVersion(remoteReleases);
var currentRelease = Utility.FindCurrentVersion(localReleases);
if (latestFullRelease == currentRelease) {
this.Log().Info("No updates, remote and local are the same");
var info = UpdateInfo.Create(currentRelease, new[] { latestFullRelease }, packageDirectory);
return info;
}
if (ignoreDeltaUpdates) {
remoteReleases = remoteReleases.Where(x => !x.IsDelta);
}
if (!localReleases.Any()) {
if (intention == UpdaterIntention.Install) {
this.Log().Info("First run, starting from scratch");
} else {
this.Log().Warn("No local releases found, starting from scratch");
}
return UpdateInfo.Create(null, new[] { latestFullRelease }, packageDirectory);
}
if (localReleases.Max(x => x.Version) > remoteReleases.Max(x => x.Version)) {
this.Log().Warn("hwhat, local version is greater than remote version");
return UpdateInfo.Create(Utility.FindCurrentVersion(localReleases), new[] { latestFullRelease }, packageDirectory);
}
return UpdateInfo.Create(currentRelease, remoteReleases, packageDirectory);
}
internal Guid? getOrCreateStagedUserId()
{
var stagedUserIdFile = Path.Combine(rootAppDirectory, "packages", ".betaId");
var ret = default(Guid);
try {
if (!Guid.TryParse(File.ReadAllText(stagedUserIdFile, Encoding.UTF8), out ret)) {
throw new Exception("File was read but contents were invalid");
}
this.Log().Info("Using existing staging user ID: {0}", ret.ToString());
return ret;
} catch (Exception ex) {
this.Log().DebugException("Couldn't read staging user ID, creating a blank one", ex);
}
var prng = new Random();
var buf = new byte[4096];
prng.NextBytes(buf);
ret = Utility.CreateGuidFromHash(buf);
try {
File.WriteAllText(stagedUserIdFile, ret.ToString(), Encoding.UTF8);
this.Log().Info("Generated new staging user ID: {0}", ret.ToString());
return ret;
} catch (Exception ex) {
this.Log().WarnException("Couldn't write out staging user ID, this user probably shouldn't get beta anything", ex);
return null;
}
ret = Utility.CreateGuidFromHash(buf);
try {
File.WriteAllText(stagedUserIdFile, ret.ToString(), Encoding.UTF8);
this.Log().Info("Generated new staging user ID: {0}", ret.ToString());
return ret;
} catch (Exception ex) {
this.Log().WarnException("Couldn't write out staging user ID, this user probably shouldn't get beta anything", ex);
return null;
}
}
}

View File

@@ -9,106 +9,57 @@ namespace Squirrel
{
public partial class UpdateManager
{
internal class DownloadReleasesImpl : IEnableLogger
/// <inheritdoc />
public virtual async Task DownloadReleases(IEnumerable<ReleaseEntry> releasesToDownload, Action<int> progress = null)
{
readonly string rootAppDirectory;
// lock will be held until this class is disposed
await acquireUpdateLock().ConfigureAwait(false);
public DownloadReleasesImpl(string rootAppDirectory)
{
this.rootAppDirectory = rootAppDirectory;
}
progress = progress ?? (_ => { });
var packagesDirectory = PackagesDirectory;
public async Task DownloadReleases(string updateUrlOrPath, IEnumerable<ReleaseEntry> releasesToDownload, Action<int> progress = null, IFileDownloader urlDownloader = null)
{
progress = progress ?? (_ => { });
urlDownloader = urlDownloader ?? Utility.CreateDefaultDownloader();
var packagesDirectory = Path.Combine(rootAppDirectory, "packages");
double current = 0;
double toIncrement = 100.0 / releasesToDownload.Count();
double current = 0;
double toIncrement = 100.0 / releasesToDownload.Count();
if (Utility.IsHttpUrl(updateUrlOrPath)) {
// From Internet
await releasesToDownload.ForEachAsync(async x => {
var targetFile = Path.Combine(packagesDirectory, x.Filename);
double component = 0;
await downloadRelease(updateUrlOrPath, x, urlDownloader, targetFile, p => {
lock (progress) {
current -= component;
component = toIncrement / 100.0 * p;
progress((int) Math.Round(current += component));
}
}).ConfigureAwait(false);
checksumPackage(x);
}).ConfigureAwait(false);
} else {
// From Disk
await releasesToDownload.ForEachAsync(x => {
var targetFile = Path.Combine(packagesDirectory, x.Filename);
File.Copy(
Path.Combine(updateUrlOrPath, x.Filename),
targetFile,
true);
lock (progress) progress((int) Math.Round(current += toIncrement));
checksumPackage(x);
}).ConfigureAwait(false);
}
}
bool isReleaseExplicitlyHttp(ReleaseEntry x)
{
return x.BaseUrl != null &&
Uri.IsWellFormedUriString(x.BaseUrl, UriKind.Absolute);
}
Task downloadRelease(string updateBaseUrl, ReleaseEntry releaseEntry, IFileDownloader urlDownloader, string targetFile, Action<int> progress)
{
var baseUri = Utility.EnsureTrailingSlash(new Uri(updateBaseUrl));
var releaseEntryUrl = releaseEntry.BaseUrl + releaseEntry.Filename;
if (!String.IsNullOrEmpty(releaseEntry.Query)) {
releaseEntryUrl += releaseEntry.Query;
}
var sourceFileUrl = new Uri(baseUri, releaseEntryUrl).AbsoluteUri;
File.Delete(targetFile);
return urlDownloader.DownloadFile(sourceFileUrl, targetFile, progress);
}
Task checksumAllPackages(IEnumerable<ReleaseEntry> releasesDownloaded)
{
return releasesDownloaded.ForEachAsync(x => checksumPackage(x));
}
void checksumPackage(ReleaseEntry downloadedRelease)
{
var targetPackage = new FileInfo(
Path.Combine(rootAppDirectory, "packages", downloadedRelease.Filename));
if (!targetPackage.Exists) {
this.Log().Error("File {0} should exist but doesn't", targetPackage.FullName);
throw new Exception("Checksummed file doesn't exist: " + targetPackage.FullName);
}
if (targetPackage.Length != downloadedRelease.Filesize) {
this.Log().Error("File Length should be {0}, is {1}", downloadedRelease.Filesize, targetPackage.Length);
targetPackage.Delete();
throw new Exception("Checksummed file size doesn't match: " + targetPackage.FullName);
}
using (var file = targetPackage.OpenRead()) {
var hash = Utility.CalculateStreamSHA1(file);
if (!hash.Equals(downloadedRelease.SHA1, StringComparison.OrdinalIgnoreCase)) {
this.Log().Error("File SHA1 should be {0}, is {1}", downloadedRelease.SHA1, hash);
targetPackage.Delete();
throw new Exception("Checksum doesn't match: " + targetPackage.FullName);
await releasesToDownload.ForEachAsync(async x => {
var targetFile = Path.Combine(packagesDirectory, x.Filename);
double component = 0;
await _updateSource.DownloadReleaseEntry(x, targetFile, p => {
lock (progress) {
current -= component;
component = toIncrement / 100.0 * p;
progress((int) Math.Round(current += component));
}
}).ConfigureAwait(false);
checksumPackage(x);
}).ConfigureAwait(false);
}
void checksumPackage(ReleaseEntry downloadedRelease)
{
var targetPackage = new FileInfo(Path.Combine(PackagesDirectory, downloadedRelease.Filename));
if (!targetPackage.Exists) {
this.Log().Error("File {0} should exist but doesn't", targetPackage.FullName);
throw new Exception("Checksummed file doesn't exist: " + targetPackage.FullName);
}
if (targetPackage.Length != downloadedRelease.Filesize) {
this.Log().Error("File Length should be {0}, is {1}", downloadedRelease.Filesize, targetPackage.Length);
targetPackage.Delete();
throw new Exception("Checksummed file size doesn't match: " + targetPackage.FullName);
}
using (var file = targetPackage.OpenRead()) {
var hash = Utility.CalculateStreamSHA1(file);
if (!hash.Equals(downloadedRelease.SHA1, StringComparison.OrdinalIgnoreCase)) {
this.Log().Error("File SHA1 should be {0}, is {1}", downloadedRelease.SHA1, hash);
targetPackage.Delete();
throw new Exception("Checksum doesn't match: " + targetPackage.FullName);
}
}
}

View File

@@ -1,88 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.Serialization;
using System.Text;
using System.Threading.Tasks;
using Squirrel.Json;
namespace Squirrel
{
public partial class UpdateManager
{
[DataContract]
private class Release
{
[DataMember(Name = "prerelease")]
public bool Prerelease { get; set; }
[DataMember(Name = "published_at")]
public DateTime PublishedAt { get; set; }
[DataMember(Name = "html_url")]
public string HtmlUrl { get; set; }
}
/// <summary>
/// This function is obsolete and will be removed in a future version,
/// see the <see cref="GithubUpdateManager" /> class for a replacement.
/// </summary>
[System.ComponentModel.EditorBrowsable(System.ComponentModel.EditorBrowsableState.Never)]
[Obsolete("Use 'new GithubUpdateManager(...)' instead")]
public static async Task<UpdateManager> GitHubUpdateManager(
string repoUrl,
string applicationName = null,
string rootDirectory = null,
IFileDownloader urlDownloader = null,
bool prerelease = false,
string accessToken = null)
{
var repoUri = new Uri(repoUrl);
var userAgent = new ProductInfoHeaderValue("Squirrel", AssemblyRuntimeInfo.ExecutingAssemblyName.Version.ToString());
if (repoUri.Segments.Length != 3) {
throw new Exception("Repo URL must be to the root URL of the repo e.g. https://github.com/myuser/myrepo");
}
var releasesApiBuilder = new StringBuilder("repos")
.Append(repoUri.AbsolutePath)
.Append("/releases");
if (!string.IsNullOrWhiteSpace(accessToken))
releasesApiBuilder.Append("?access_token=").Append(accessToken);
Uri baseAddress;
if (repoUri.Host.EndsWith("github.com", StringComparison.OrdinalIgnoreCase)) {
baseAddress = new Uri("https://api.github.com/");
} else {
// if it's not github.com, it's probably an Enterprise server
// now the problem with Enterprise is that the API doesn't come prefixed
// it comes suffixed
// so the API path of http://internal.github.server.local API location is
// http://interal.github.server.local/api/v3.
baseAddress = new Uri(string.Format("{0}{1}{2}/api/v3/", repoUri.Scheme, Uri.SchemeDelimiter, repoUri.Host));
}
// above ^^ notice the end slashes for the baseAddress, explained here: http://stackoverflow.com/a/23438417/162694
using (var client = new HttpClient() { BaseAddress = baseAddress }) {
client.DefaultRequestHeaders.UserAgent.Add(userAgent);
var response = await client.GetAsync(releasesApiBuilder.ToString()).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
var releases = SimpleJson.DeserializeObject<List<Release>>(await response.Content.ReadAsStringAsync().ConfigureAwait(false));
var latestRelease = releases
.Where(x => prerelease || !x.Prerelease)
.OrderByDescending(x => x.PublishedAt)
.First();
var latestReleaseUrl = latestRelease.HtmlUrl.Replace("/tag/", "/download/");
return new UpdateManager(latestReleaseUrl, applicationName, rootDirectory, urlDownloader);
}
}
}
}

View File

@@ -10,6 +10,7 @@ using Microsoft.Win32;
using Squirrel.NuGet;
using Squirrel.SimpleSplat;
using Squirrel.Shell;
using Squirrel.Sources;
namespace Squirrel
{
@@ -28,8 +29,8 @@ namespace Squirrel
/// <summary>True if the current executable is inside the target <see cref="AppDirectory"/>.</summary>
public bool IsInstalledApp => isUpdateExeAvailable() ? Utility.IsFileInDirectory(AssemblyRuntimeInfo.EntryExePath, AppDirectory) : false;
/// <summary>The url to use when checking for or downloading updates.</summary>
protected string _updateUrlOrPath;
/// <summary>The directory packages and temp files are stored in.</summary>
protected string PackagesDirectory => Utility.PackageDirectoryForAppDir(AppDirectory);
/// <summary>The application name provided in constructor, or null.</summary>
protected readonly string _applicationIdOverride;
@@ -37,8 +38,8 @@ namespace Squirrel
/// <summary>The path to the local app data folder on this machine.</summary>
protected readonly string _localAppDataDirectoryOverride;
/// <summary>The <see cref="IFileDownloader"/> to use when downloading data from the internet.</summary>
protected readonly IFileDownloader _urlDownloader;
/// <summary>The <see cref="IUpdateSource"/> responsible for retrieving updates from a package repository.</summary>
protected readonly IUpdateSource _updateSource;
private readonly object _lockobj = new object();
private IDisposable _updateLock;
@@ -46,7 +47,10 @@ namespace Squirrel
/// <summary>
/// Create a new instance of <see cref="UpdateManager"/> to check for and install updates.
/// Do not forget to dispose this class!
/// Do not forget to dispose this class! This constructor is just a shortcut for
/// <see cref="UpdateManager(IUpdateSource, string, string)"/>, and will automatically create
/// a <see cref="SimpleFileSource"/> or a <see cref="SimpleWebSource"/> depending on
/// whether 'urlOrPath' is a filepath or a URL, respectively.
/// </summary>
/// <param name="urlOrPath">
/// The URL where your update packages or stored, or a local package repository directory.
@@ -70,37 +74,47 @@ namespace Squirrel
string applicationIdOverride = null,
string localAppDataDirectoryOverride = null,
IFileDownloader urlDownloader = null)
: this(CreateSourceFromString(urlOrPath, urlDownloader), applicationIdOverride, localAppDataDirectoryOverride)
{ }
/// <summary>
/// Create a new instance of <see cref="UpdateManager"/> to check for and install updates.
/// Do not forget to dispose this class!
/// </summary>
/// <param name="updateSource">
/// The source of your update packages. This can be a web server (<see cref="SimpleWebSource"/>),
/// a local directory (<see cref="SimpleFileSource"/>), a GitHub repository (<see cref="GithubSource"/>),
/// or a custom location.
/// </param>
/// <param name="applicationIdOverride">
/// The Id of your application should correspond with the
/// appdata directory name, and the Id used with Squirrel releasify/pack.
/// If left null/empty, UpdateManger will attempt to determine the current application Id
/// from the installed app location, or throw if the app is not currently installed during certain
/// operations.
/// </param>
/// <param name="localAppDataDirectoryOverride">
/// Provide a custom location for the system LocalAppData, it will be used
/// instead of <see cref="Environment.SpecialFolder.LocalApplicationData"/>.
/// </param>
public UpdateManager(
IUpdateSource updateSource,
string applicationIdOverride = null,
string localAppDataDirectoryOverride = null)
{
_updateUrlOrPath = urlOrPath;
_updateSource = updateSource;
_applicationIdOverride = applicationIdOverride;
_localAppDataDirectoryOverride = localAppDataDirectoryOverride;
_urlDownloader = urlDownloader ?? Utility.CreateDefaultDownloader();
}
internal UpdateManager() { }
/// <summary>Clean up UpdateManager resources</summary>
~UpdateManager()
{
Dispose();
}
/// <inheritdoc/>
public virtual async Task<UpdateInfo> CheckForUpdate(bool ignoreDeltaUpdates = false, Action<int> progress = null, UpdaterIntention intention = UpdaterIntention.Update)
{
var checkForUpdate = new CheckForUpdateImpl(AppDirectory);
await acquireUpdateLock().ConfigureAwait(false);
return await checkForUpdate.CheckForUpdate(intention, Utility.LocalReleaseFileForAppDir(AppDirectory), _updateUrlOrPath, ignoreDeltaUpdates, progress, _urlDownloader).ConfigureAwait(false);
}
/// <inheritdoc/>
public virtual async Task DownloadReleases(IEnumerable<ReleaseEntry> releasesToDownload, Action<int> progress = null)
{
var downloadReleases = new DownloadReleasesImpl(AppDirectory);
await acquireUpdateLock().ConfigureAwait(false);
await downloadReleases.DownloadReleases(_updateUrlOrPath, releasesToDownload, progress, _urlDownloader).ConfigureAwait(false);
}
/// <inheritdoc/>
public async Task<string> ApplyReleases(UpdateInfo updateInfo, Action<int> progress = null)
{
@@ -315,6 +329,15 @@ namespace Squirrel
return Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData);
}
private static IUpdateSource CreateSourceFromString(string urlOrPath, IFileDownloader urlDownloader)
{
if (Utility.IsHttpUrl(urlOrPath)) {
return new SimpleWebSource(urlOrPath, urlDownloader ?? Utility.CreateDefaultDownloader());
} else {
return new SimpleFileSource(new DirectoryInfo(urlOrPath));
}
}
private Task<IDisposable> acquireUpdateLock()
{
lock (_lockobj) {

View File

@@ -5,6 +5,7 @@ using System.Linq;
using System.Threading.Tasks;
using Squirrel;
using Squirrel.SimpleSplat;
using Squirrel.Sources;
namespace SquirrelCli.Sources
{
@@ -21,61 +22,36 @@ namespace SquirrelCli.Sources
public async Task DownloadRecentPackages()
{
var dl = Utility.CreateDefaultDownloader();
var releaseDirectoryInfo = new DirectoryInfo(_options.releaseDir);
if (!releaseDirectoryInfo.Exists)
releaseDirectoryInfo.Create();
var releases = await GithubUpdateManager.GetGithubReleases(
new Uri(_options.repoUrl), _options.token, _options.pre, dl);
if (String.IsNullOrWhiteSpace(_options.token))
Log.Warn("No GitHub access token provided. Unauthenticated requests will be limited to 60 per hour.");
if (!releases.Any()) {
Log.Warn("No github releases found.");
Log.Info("Fetching RELEASES...");
var source = new GithubSource(_options.repoUrl, _options.token, _options.pre);
var latestReleaseEntries = await source.GetReleaseFeed();
if (latestReleaseEntries == null || latestReleaseEntries.Length == 0) {
Log.Warn("No github release or assets found.");
return;
}
string bearer = null;
if (!string.IsNullOrWhiteSpace(_options.token))
bearer = "Bearer " + _options.token;
Log.Info($"Found {latestReleaseEntries.Length} assets in latest release ({source.Release.Name}).");
var lastRelease = await GetLastReleaseUrl(releases, dl, bearer);
if (lastRelease.Url == null) {
Log.Warn("No github releases found with a valid release attached.");
return;
}
Log.Info("Downloading package from " + lastRelease.Url);
var localFile = Path.Combine(releaseDirectoryInfo.FullName, lastRelease.Filename);
await dl.DownloadFile(lastRelease.Url, localFile, null, bearer);
var rf = ReleaseEntry.GenerateFromFile(localFile);
ReleaseEntry.WriteReleaseFile(new[] { rf }, Path.Combine(releaseDirectoryInfo.FullName, "RELEASES"));
}
private async Task<(string Url, string Filename)> GetLastReleaseUrl(IEnumerable<GithubUpdateManager.GithubRelease> releases, IFileDownloader dl, string bearer)
{
foreach (var r in releases) {
var releasesUrl = Utility.AppendPathToUri(new Uri(r.DownloadUrl), "RELEASES");
Log.Info("Downloading metadata from " + releasesUrl);
var releasesText = await dl.DownloadString(releasesUrl.ToString(), bearer);
var entries = ReleaseEntry.ParseReleaseFile(releasesText);
var latestAsset = entries
.Where(p => p.Version != null)
.Where(p => !p.IsDelta)
.OrderByDescending(p => p.Version)
.FirstOrDefault();
if (latestAsset != null) {
return (Utility.AppendPathToUri(new Uri(r.DownloadUrl), latestAsset.Filename).ToString(), latestAsset.Filename);
foreach (var entry in latestReleaseEntries) {
var localFile = Path.Combine(releaseDirectoryInfo.FullName, entry.Filename);
if (File.Exists(localFile)) {
Log.Info($"File '{entry.Filename}' exists on disk, skipping download.");
continue;
}
Log.Info($"Downloading {entry.Filename}...");
await source.DownloadReleaseEntry(entry, localFile, (p) => { });
}
return (null, null);
ReleaseEntry.BuildReleasesFile(releaseDirectoryInfo.FullName);
Log.Info("Done.");
}
public Task UploadMissingPackages()

View File

@@ -15,19 +15,19 @@ using Xunit;
namespace Squirrel.Tests
{
public class FakeUrlDownloader : IFileDownloader
public class FakeUrlDownloader : Sources.IFileDownloader
{
public Task<byte[]> DownloadBytes(string url, string auth)
public Task<byte[]> DownloadBytes(string url, string auth, string acc)
{
return Task.FromResult(new byte[0]);
}
public Task DownloadFile(string url, string targetFile, Action<int> progress, string auth)
public Task DownloadFile(string url, string targetFile, Action<int> progress, string auth, string acc)
{
return Task.CompletedTask;
}
public Task<string> DownloadString(string url, string auth)
public Task<string> DownloadString(string url, string auth, string acc)
{
return Task.FromResult("");
}

View File

@@ -336,7 +336,7 @@ namespace Squirrel.Tests
[Fact]
public void CurrentlyInstalledVersionDoesNotThrow()
{
using var fixture = new UpdateManager(null);
using var fixture = new UpdateManager();
Assert.Null(fixture.CurrentlyInstalledVersion());
Assert.False(fixture.IsInstalledApp);
}