1
0
mirror of https://github.com/bitwarden/server.git synced 2025-07-01 16:12:49 -05:00

Rewrite Icon fetching (#3023)

* Rewrite Icon fetching

* Move validation to IconUri, Uri, or UriBuilder

* `dotnet format` 🤖

* PR suggestions

* Add not null compiler hint

* Add twitter to test case

* Move Uri manipulation to UriService

* Implement MockedHttpClient

Presents better, fluent handling of message matching and response
building.

* Add redirect handling tests

* Add testing to models

* More aggressively dispose content in icon link

* Format 🤖

* Update icon lockfile

* Convert to cloned stream for HttpResponseBuilder

Content was being disposed when HttpResponseMessage was being disposed.
This avoids losing our reference to our content and allows multiple
usages of the same `MockedHttpMessageResponse`

* Move services to extension

Extension is shared by testing and allows access to services from
our service tests

* Remove unused `using`

* Prefer awaiting asyncs for better exception handling

* `dotnet format` 🤖

* Await async

* Update tests to use test TLD and ip ranges

* Remove unused interfaces

* Make assignments static when possible

* Prefer invariant comparer to downcasing

* Prefer injecting interface services to implementations

* Prefer comparer set in HashSet initialization

* Allow SVG icons

* Filter out icons with unknown formats

* Seek to beginning of MemoryStream after writing it

* More appropriate to not return icon if it's invalid

* Add svg icon test
This commit is contained in:
Matt Gibson
2023-08-08 15:29:40 -04:00
committed by GitHub
parent ca368466ce
commit 4377c7a897
31 changed files with 1685 additions and 522 deletions

View File

@ -81,7 +81,7 @@ public class IconsController : Controller
}
else
{
icon = result.Icon;
icon = result;
}
// Only cache not found and smaller images (<= 50kb)

View File

@ -0,0 +1,100 @@
#nullable enable
using System.Collections;
using AngleSharp.Html.Parser;
using Bit.Icons.Extensions;
using Bit.Icons.Services;
namespace Bit.Icons.Models;
/// <summary>
/// The icons discovered for a single domain. Instances are created through
/// <see cref="FetchAsync"/>, which probes a sequence of candidate pages and keeps
/// the icons of the first page that answers with a success status.
/// </summary>
public class DomainIcons : IEnumerable<Icon>
{
    private readonly ILogger<IIconFetchingService> _logger;
    private readonly IHttpClientFactory _httpClientFactory;
    private readonly IUriService _uriService;
    private readonly List<Icon> _icons = new();

    /// <summary>The domain exactly as it was handed to <see cref="FetchAsync"/>.</summary>
    public string Domain { get; }

    /// <summary>Returns the icon stored at position <paramref name="i"/>.</summary>
    public Icon this[int i] => _icons[i];

    public IEnumerator<Icon> GetEnumerator() => _icons.GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    private DomainIcons(string domain, ILogger<IIconFetchingService> logger, IHttpClientFactory httpClientFactory, IUriService uriService)
    {
        Domain = domain;
        _logger = logger;
        _httpClientFactory = httpClientFactory;
        _uriService = uriService;
    }

    /// <summary>
    /// Builds a <see cref="DomainIcons"/> for <paramref name="domain"/> and populates it.
    /// </summary>
    /// <returns>The populated collection; it is empty when no candidate page succeeded.</returns>
    public static async Task<DomainIcons> FetchAsync(string domain, ILogger<IIconFetchingService> logger, IHttpClientFactory httpClientFactory, IHtmlParser parser, IUriService uriService)
    {
        var domainIcons = new DomainIcons(domain, logger, httpClientFactory, uriService);
        await domainIcons.FetchIconsAsync(parser);
        return domainIcons;
    }

    /// <summary>
    /// Tries, in order: https on the domain, http on the domain, https on the base domain
    /// (only when the domain has more than one dot) and https on the "www." host (only when
    /// it has fewer than two dots). Stops at the first page that responds successfully.
    /// </summary>
    private async Task FetchIconsAsync(IHtmlParser parser)
    {
        if (!Uri.TryCreate($"https://{Domain}", UriKind.Absolute, out var httpsUri))
        {
            _logger.LogWarning("Bad domain: {domain}.", Domain);
            return;
        }

        if (await TryLoadIconsFromAsync(httpsUri, parser))
        {
            return;
        }

        if (await TryLoadIconsFromAsync(httpsUri.ChangeScheme("http"), parser))
        {
            return;
        }

        var dotCount = Domain.Count(c => c == '.');

        if (dotCount > 1 && DomainName.TryParseBaseDomain(Domain, out var baseDomain) &&
            Uri.TryCreate($"https://{baseDomain}", UriKind.Absolute, out var baseDomainUri))
        {
            if (await TryLoadIconsFromAsync(baseDomainUri, parser))
            {
                return;
            }
        }

        if (dotCount < 2 && Uri.TryCreate($"https://www.{httpsUri.Host}", UriKind.Absolute, out var wwwUri))
        {
            await TryLoadIconsFromAsync(wwwUri, parser);
        }
    }

    /// <summary>
    /// Requests <paramref name="pageUri"/> and, when the response is successful, appends the
    /// icons retrieved from it to this collection.
    /// </summary>
    /// <returns><c>true</c> when the page answered with a success status code.</returns>
    private async Task<bool> TryLoadIconsFromAsync(Uri pageUri, IHtmlParser parser)
    {
        using var response = await IconHttpRequest.FetchAsync(pageUri, _logger, _httpClientFactory, _uriService);
        if (!response.IsSuccessStatusCode)
        {
            return false;
        }

        _icons.AddRange(await response.RetrieveIconsAsync(pageUri, Domain, parser));
        return true;
    }
}

View File

@ -0,0 +1,110 @@
#nullable enable
using System.Net;
using Bit.Icons.Extensions;
using Bit.Icons.Services;
namespace Bit.Icons.Models;
/// <summary>
/// A single GET request for a page or icon that follows a bounded number of redirects
/// manually, validating every hop through <see cref="IUriService"/> / <see cref="IconUri"/>.
/// </summary>
public class IconHttpRequest
{
    private const int _maxRedirects = 2;

    // 308 (PermanentRedirect) is the method-preserving sibling of 301, just as 307 is of 302.
    private static readonly HttpStatusCode[] _redirectStatusCodes = new HttpStatusCode[]
    {
        HttpStatusCode.Redirect,
        HttpStatusCode.MovedPermanently,
        HttpStatusCode.RedirectKeepVerb,
        HttpStatusCode.SeeOther,
        HttpStatusCode.PermanentRedirect,
    };

    private readonly ILogger<IIconFetchingService> _logger;
    private readonly HttpClient _httpClient;
    private readonly IHttpClientFactory _httpClientFactory;
    private readonly IUriService _uriService;
    private readonly int _redirectsCount;
    private readonly Uri _uri;

    // A fresh message per access: the caller owns and disposes whatever is returned.
    private static HttpResponseMessage NotFound => new(HttpStatusCode.NotFound);

    private IconHttpRequest(Uri uri, ILogger<IIconFetchingService> logger, IHttpClientFactory httpClientFactory, IUriService uriService, int redirectsCount)
    {
        _logger = logger;
        _httpClientFactory = httpClientFactory;
        _httpClient = _httpClientFactory.CreateClient("Icons");
        _uriService = uriService;
        _redirectsCount = redirectsCount;
        _uri = uri;
    }

    /// <summary>
    /// Requests <paramref name="uri"/>, following up to two redirects.
    /// </summary>
    /// <returns>
    /// The wrapped final response. Invalid targets, transport failures, non-redirect error
    /// statuses and exhausted redirect budgets are all reported as a 404 response.
    /// </returns>
    public static async Task<IconHttpResponse> FetchAsync(Uri uri, ILogger<IIconFetchingService> logger, IHttpClientFactory httpClientFactory, IUriService uriService)
    {
        var request = new IconHttpRequest(uri, logger, httpClientFactory, uriService, 0);
        var httpResponse = await request.FetchAsync();
        return new IconHttpResponse(httpResponse, logger, httpClientFactory, uriService);
    }

    private async Task<HttpResponseMessage> FetchAsync()
    {
        if (!_uriService.TryGetUri(_uri, out var iconUri) || !iconUri!.IsValid)
        {
            return NotFound;
        }

        var response = await GetAsync(iconUri);
        if (response.IsSuccessStatusCode)
        {
            return response;
        }

        // The unsuccessful response is only needed to read its redirect headers;
        // release it once the redirect chain has been resolved.
        using (response)
        {
            return await FollowRedirectsAsync(response, iconUri);
        }
    }

    private async Task<HttpResponseMessage> GetAsync(IconUri iconUri)
    {
        using var message = new HttpRequestMessage();
        message.RequestUri = iconUri.InnerUri;
        // InnerUri targets the resolved IP, so the original name travels in the Host header.
        message.Headers.Host = iconUri.Host;
        message.Method = HttpMethod.Get;

        try
        {
            return await _httpClient.SendAsync(message);
        }
        catch
        {
            // Best effort: any transport failure is treated as "nothing there".
            return NotFound;
        }
    }

    /// <summary>
    /// Issues the follow-up request for a redirect response. The caller keeps ownership of
    /// <paramref name="response"/> and is responsible for disposing it.
    /// </summary>
    private async Task<HttpResponseMessage> FollowRedirectsAsync(HttpResponseMessage response, IconUri originalIconUri)
    {
        if (_redirectsCount >= _maxRedirects || response.Headers.Location == null ||
            !_redirectStatusCodes.Contains(response.StatusCode))
        {
            return NotFound;
        }

        var redirectUri = DetermineRedirectUri(response.Headers.Location, originalIconUri);
        return await new IconHttpRequest(redirectUri, _logger, _httpClientFactory, _uriService, _redirectsCount + 1).FetchAsync();
    }

    /// <summary>
    /// Turns a Location header value into an absolute URI. Absolute non-http(s) targets are
    /// rewritten to https; relative targets are resolved against the original request URI.
    /// </summary>
    private static Uri DetermineRedirectUri(Uri responseUri, IconUri originalIconUri)
    {
        if (responseUri.IsAbsoluteUri)
        {
            return responseUri.IsHypertext() ? responseUri : responseUri.ChangeScheme("https");
        }

        // InnerUri carries the request's scheme/path/query but points at the resolved IP, so
        // restore the original host name before resolving. Resolving with Uri(base, relative)
        // instead of assigning UriBuilder.Path keeps "?query" and "#fragment" intact (the Path
        // setter would percent-encode them) and honours path-relative locations.
        var requestUri = new UriBuilder(originalIconUri.InnerUri) { Host = originalIconUri.Host }.Uri;
        return new Uri(requestUri, responseUri);
    }
}

View File

@ -0,0 +1,72 @@
#nullable enable
using System.Net;
using AngleSharp.Html.Parser;
using Bit.Icons.Services;
namespace Bit.Icons.Models;
/// <summary>
/// Wraps the <see cref="HttpResponseMessage"/> of a page request and extracts the icons
/// referenced from that page's head. Disposing this object disposes the wrapped response.
/// </summary>
public class IconHttpResponse : IDisposable
{
    private const int _maxIconLinksProcessed = 200;
    private const int _maxRetrievedIcons = 10;

    private readonly HttpResponseMessage _response;
    private readonly ILogger<IIconFetchingService> _logger;
    private readonly IHttpClientFactory _httpClientFactory;
    private readonly IUriService _uriService;

    public IconHttpResponse(HttpResponseMessage response, ILogger<IIconFetchingService> logger, IHttpClientFactory httpClientFactory, IUriService uriService)
    {
        _response = response;
        _logger = logger;
        _httpClientFactory = httpClientFactory;
        _uriService = uriService;
    }

    public HttpStatusCode StatusCode => _response.StatusCode;

    public bool IsSuccessStatusCode => _response.IsSuccessStatusCode;

    public string? ContentType => _response.Content.Headers.ContentType?.MediaType;

    public HttpContent Content => _response.Content;

    /// <summary>
    /// Parses the response body, collects its <c>link[href]</c> elements and downloads the
    /// best-ranked usable ones.
    /// </summary>
    /// <param name="requestUri">Fallback page URI, used when the response has no request URI or that URI's host is an IP.</param>
    /// <param name="domain">Domain name used only for the log message.</param>
    /// <param name="parser">Parser used to read the document head.</param>
    /// <returns>The icons that were fetched successfully; empty when the head could not be parsed.</returns>
    public async Task<IEnumerable<Icon>> RetrieveIconsAsync(Uri requestUri, string domain, IHtmlParser parser)
    {
        using var htmlStream = await _response.Content.ReadAsStreamAsync();
        var head = await parser.ParseHeadAsync(htmlStream);
        if (head == null)
        {
            _logger.LogWarning("No DocumentElement for {domain}.", domain);
            return Array.Empty<Icon>();
        }

        // Links must be resolved against a host name, never against an IP address.
        var responseUri = _response.RequestMessage?.RequestUri;
        var pageUri = responseUri != null && !IPAddress.TryParse(responseUri.Host, out _)
            ? responseUri
            : requestUri;

        var baseHref = head.QuerySelector("base[href]")?.Attributes["href"]?.Value;
        var baseUrl = string.IsNullOrWhiteSpace(baseHref) ? "/" : baseHref;

        var linkElements = head.QuerySelectorAll("link[href]");
        var candidates = linkElements == null
            ? Array.Empty<IconLink>()
            : linkElements
                .Take(_maxIconLinksProcessed)
                .Select(element => new IconLink(element, pageUri, baseUrl))
                .Where(link => link.IsUsable())
                .OrderBy(link => link.Priority)
                .Take(_maxRetrievedIcons)
                .ToArray();

        var downloads = candidates.Select(link => link.FetchAsync(_logger, _httpClientFactory, _uriService));
        var fetched = await Task.WhenAll(downloads);

        // OfType drops the null entries produced by failed downloads.
        return fetched.OfType<Icon>();
    }

    public void Dispose() => _response.Dispose();
}

View File

@ -0,0 +1,220 @@
#nullable enable
using System.Text;
using AngleSharp.Dom;
using Bit.Icons.Extensions;
using Bit.Icons.Services;
namespace Bit.Icons.Models;
/// <summary>
/// A candidate icon reference: either a <c>&lt;link&gt;</c> element found on a page or a
/// direct URI (see <see cref="IconLink(Uri)"/>). It knows how to rank itself, decide whether
/// it plausibly points at an icon, and download the image.
/// </summary>
public class IconLink
{
    private static readonly HashSet<string> _iconRels = new(StringComparer.InvariantCultureIgnoreCase) { "icon", "apple-touch-icon", "shortcut icon" };
    private static readonly HashSet<string> _blocklistedRels = new(StringComparer.InvariantCultureIgnoreCase) { "preload", "image_src", "preconnect", "canonical", "alternate", "stylesheet" };
    private static readonly HashSet<string> _iconExtensions = new(StringComparer.InvariantCultureIgnoreCase) { ".ico", ".png", ".jpg", ".jpeg" };
    private const string _pngMediaType = "image/png";
    private static readonly byte[] _pngHeader = new byte[] { 137, 80, 78, 71 };
    private static readonly byte[] _webpHeader = Encoding.UTF8.GetBytes("RIFF");
    private const string _icoMediaType = "image/x-icon";
    private const string _icoAltMediaType = "image/vnd.microsoft.icon";
    private static readonly byte[] _icoHeader = new byte[] { 00, 00, 01, 00 };
    private const string _jpegMediaType = "image/jpeg";
    private static readonly byte[] _jpegHeader = new byte[] { 255, 216, 255 };
    private const string _svgXmlMediaType = "image/svg+xml";
    private static readonly HashSet<string> _allowedMediaTypes = new(StringComparer.InvariantCultureIgnoreCase)
    {
        _pngMediaType,
        _icoMediaType,
        _icoAltMediaType,
        _jpegMediaType,
        _svgXmlMediaType,
    };

    // The HTML "sizes" syntax allows either a lowercase or an uppercase X between the numbers.
    private static readonly char[] _sizeSeparators = new[] { 'x', 'X' };
    private static readonly char[] _queryOrFragmentStart = new[] { '?', '#' };

    private bool _useUriDirectly = false;
    private bool _validated = false;
    private int? _width;
    private int? _height;

    public IAttr? Href { get; }
    public IAttr? Rel { get; }
    public IAttr? Type { get; }
    public IAttr? Sizes { get; }
    public Uri ParentUri { get; }
    public string BaseUrlPath { get; }

    /// <summary>
    /// Rank derived from the declared size; a smaller number is a better candidate.
    /// Square 32px is best, then 64px, then any other square between 24 and 128px, then 16px.
    /// Other square sizes rank 100; missing or non-square sizes rank 200.
    /// </summary>
    public int Priority
    {
        get
        {
            if (_width == null || _width != _height)
            {
                return 200;
            }

            return _width switch
            {
                32 => 1,
                64 => 2,
                >= 24 and <= 128 => 3,
                16 => 4,
                _ => 100,
            };
        }
    }

    /// <summary>
    /// Creates a link that fetches <paramref name="parentPage"/> itself. It is considered
    /// validated from the start, so <see cref="IsUsable"/> does not need to be called.
    /// </summary>
    public IconLink(Uri parentPage)
    {
        _useUriDirectly = true;
        _validated = true;
        ParentUri = parentPage;
        BaseUrlPath = parentPage.PathAndQuery;
    }

    /// <summary>
    /// Creates a link from a <c>&lt;link&gt;</c> element found on <paramref name="parentPage"/>.
    /// </summary>
    /// <param name="element">The element whose href/rel/type/sizes attributes are read.</param>
    /// <param name="parentPage">The page the element was found on; supplies scheme and host for relative hrefs.</param>
    /// <param name="baseUrlPath">The page's base path that relative hrefs are appended to.</param>
    public IconLink(IElement element, Uri parentPage, string baseUrlPath)
    {
        Href = element.Attributes["href"];
        ParentUri = parentPage;
        BaseUrlPath = baseUrlPath;
        Rel = element.Attributes["rel"];
        Type = element.Attributes["type"];
        Sizes = element.Attributes["sizes"];

        var sizes = Sizes?.Value;
        if (string.IsNullOrWhiteSpace(sizes))
        {
            return;
        }

        // Only a single "WxH" / "WXH" pair is understood; anything else leaves the size unknown.
        var sizeParts = sizes.Split(_sizeSeparators);
        if (sizeParts.Length == 2 && int.TryParse(sizeParts[0].Trim(), out var width) &&
            int.TryParse(sizeParts[1].Trim(), out var height))
        {
            _width = width;
            _height = height;
        }
    }

    /// <summary>
    /// Decides whether this link plausibly points at an icon and records the verdict so that
    /// <see cref="FetchAsync"/> is allowed to run. A link qualifies when its rel is a known
    /// icon rel, or when its rel is not blocklisted and its href ends in a known image extension.
    /// </summary>
    public bool IsUsable()
    {
        var href = Href?.Value;
        if (string.IsNullOrWhiteSpace(href))
        {
            return false;
        }

        if (Rel != null && _iconRels.Contains(Rel.Value))
        {
            _validated = true;
        }

        if (Rel == null || !_blocklistedRels.Contains(Rel.Value))
        {
            try
            {
                // Ignore "?query" and "#fragment" so that cache-busted hrefs such as
                // "icon.png?v=2" are still recognised by their extension.
                var extension = Path.GetExtension(StripQueryAndFragment(href));
                if (_iconExtensions.Contains(extension))
                {
                    _validated = true;
                }
            }
            catch (ArgumentException) { }
        }

        return _validated;
    }

    /// <summary>
    /// Fetches the icon from the Href. Will always fail unless first validated with IsUsable().
    /// </summary>
    /// <returns>
    /// The downloaded icon, or <c>null</c> when the link is not validated, no URI can be built,
    /// the request fails, or the image format is not one of the allowed media types.
    /// </returns>
    public async Task<Icon?> FetchAsync(ILogger<IIconFetchingService> logger, IHttpClientFactory httpClientFactory, IUriService uriService)
    {
        if (!_validated)
        {
            return null;
        }

        var uri = BuildUri();
        if (uri == null)
        {
            return null;
        }

        using var response = await IconHttpRequest.FetchAsync(uri, logger, httpClientFactory, uriService);
        if (!response.IsSuccessStatusCode)
        {
            return null;
        }

        var format = response.Content.Headers.ContentType?.MediaType;
        var bytes = await response.Content.ReadAsByteArrayAsync();
        response.Content.Dispose();

        // Fall back to sniffing the leading bytes when the declared type is missing or not allowed.
        if (format == null || !_allowedMediaTypes.Contains(format))
        {
            format = DetermineImageFormatFromFile(bytes);
        }

        if (format == null || !_allowedMediaTypes.Contains(format))
        {
            return null;
        }

        return new Icon { Image = bytes, Format = format };
    }

    /// <summary>
    /// Resolves the href into an absolute URI: protocol-relative ("//host/..") hrefs borrow the
    /// parent's scheme, relative hrefs are appended to the parent's host and base path, and
    /// absolute hrefs are used as they are.
    /// </summary>
    private Uri? BuildUri()
    {
        if (_useUriDirectly)
        {
            return ParentUri;
        }

        if (Href == null)
        {
            return null;
        }

        if (Href.Value.StartsWith("//") && Uri.TryCreate($"{ParentUri.Scheme}://{Href.Value[2..]}", UriKind.Absolute, out var uri))
        {
            return uri;
        }

        if (Uri.TryCreate(Href.Value, UriKind.Relative, out uri))
        {
            return new UriBuilder()
            {
                Scheme = ParentUri.Scheme,
                Host = ParentUri.Host,
            }.Uri.ConcatPath(BaseUrlPath, uri.OriginalString);
        }

        if (Uri.TryCreate(Href.Value, UriKind.Absolute, out uri))
        {
            return uri;
        }

        return null;
    }

    /// <summary>Returns <paramref name="href"/> truncated at the first '?' or '#', if any.</summary>
    private static string StripQueryAndFragment(string href)
    {
        var cut = href.IndexOfAny(_queryOrFragmentStart);
        return cut < 0 ? href : href[..cut];
    }

    private static bool HeaderMatch(byte[] imageBytes, byte[] header)
    {
        return imageBytes.Length >= header.Length && header.SequenceEqual(imageBytes.Take(header.Length));
    }

    /// <summary>
    /// Guesses the media type from the file's leading bytes; returns an empty string when no
    /// known signature matches.
    /// </summary>
    private static string DetermineImageFormatFromFile(byte[] imageBytes)
    {
        if (HeaderMatch(imageBytes, _icoHeader))
        {
            return _icoMediaType;
        }

        // NOTE(review): files starting with "RIFF" are reported as image/png here — confirm this is intended.
        if (HeaderMatch(imageBytes, _pngHeader) || HeaderMatch(imageBytes, _webpHeader))
        {
            return _pngMediaType;
        }

        if (HeaderMatch(imageBytes, _jpegHeader))
        {
            return _jpegMediaType;
        }

        return string.Empty;
    }
}

View File

@ -1,65 +0,0 @@
namespace Bit.Icons.Models;
/// <summary>
/// An icon candidate together with the ranking used to pick the best one.
/// A lower <see cref="Priority"/> is preferred.
/// </summary>
public class IconResult
{
    /// <summary>
    /// Creates a not-yet-downloaded candidate from a link's href and its "sizes" attribute.
    /// The priority is 200 unless <paramref name="sizes"/> is a parsable square "WxH" value.
    /// </summary>
    public IconResult(string href, string sizes)
    {
        Path = href;
        Priority = 200;

        if (string.IsNullOrWhiteSpace(sizes))
        {
            return;
        }

        var dimensions = sizes.Split('x');
        if (dimensions.Length != 2 ||
            !int.TryParse(dimensions[0].Trim(), out var width) ||
            !int.TryParse(dimensions[1].Trim(), out var height))
        {
            return;
        }

        DefinedWidth = width;
        DefinedHeight = height;

        // Only square icons get a size-based rank.
        if (width == height)
        {
            Priority = width switch
            {
                32 => 1,
                64 => 2,
                >= 24 and <= 128 => 3,
                16 => 4,
                _ => 100,
            };
        }
    }

    /// <summary>
    /// Creates an already-downloaded result with a fixed priority of 10.
    /// </summary>
    public IconResult(Uri uri, byte[] bytes, string format)
    {
        Path = uri.ToString();
        Priority = 10;
        Icon = new Icon
        {
            Image = bytes,
            Format = format
        };
    }

    public string Path { get; set; }
    public int? DefinedWidth { get; set; }
    public int? DefinedHeight { get; set; }
    public Icon Icon { get; set; }
    public int Priority { get; set; }
}

View File

@ -0,0 +1,52 @@
#nullable enable
using System.Net;
using Bit.Icons.Extensions;
namespace Bit.Icons.Models;
/// <summary>
/// Pairs a requested URI with the IP address it was resolved to. <see cref="InnerUri"/> is
/// the request URI with its host replaced by that IP, while <see cref="Host"/> keeps the
/// original host name.
/// </summary>
public class IconUri
{
    private readonly IPAddress _ip;

    /// <summary>
    /// Represents an ip-validated Uri for use in grabbing an icon.
    /// </summary>
    /// <param name="uri">The URI that was requested.</param>
    /// <param name="ip">The address substituted for the host of <paramref name="uri"/>.</param>
    public IconUri(Uri uri, IPAddress ip)
    {
        _ip = ip;
        Host = uri.Host;
        InnerUri = uri.ChangeHost(ip.ToString());
    }

    /// <summary>Host of the URI as originally requested.</summary>
    public string Host { get; }

    /// <summary>The requested URI with its host swapped for the resolved IP address.</summary>
    public Uri InnerUri { get; }

    public string Scheme => InnerUri.Scheme;

    /// <summary>
    /// <c>true</c> only when the original host is a dotted name rather than an IP literal,
    /// the scheme is http or https on its default port, and the resolved address is not internal.
    /// </summary>
    public bool IsValid
    {
        get
        {
            // Prevent direct access to any ip
            var hostIsIpLiteral = IPAddress.TryParse(Host, out _);

            // Prevent non-http(s) and non-default ports
            var isHttpOrHttps = InnerUri.Scheme == "http" || InnerUri.Scheme == "https";
            var usesDefaultPort = InnerUri.IsDefaultPort;

            // Prevent local hosts (localhost, bobs-pc, etc) and internal addresses
            var isPublicName = Host.Contains('.') && !_ip.IsInternal();

            return !hostIsIpLiteral && isHttpOrHttps && usesDefaultPort && isPublicName;
        }
    }
}

View File

@ -1,8 +1,10 @@
using Bit.Icons.Models;
#nullable enable
using Bit.Icons.Models;
namespace Bit.Icons.Services;
public interface IIconFetchingService
{
Task<IconResult> GetIconAsync(string domain);
Task<Icon?> GetIconAsync(string domain);
}

View File

@ -0,0 +1,12 @@
#nullable enable
using Bit.Icons.Models;
namespace Bit.Icons.Services;
/// <summary>
/// Produces <see cref="IconUri"/> instances from raw URIs and from redirect responses.
/// Every method follows the Try-pattern: the boolean result reports success and the
/// out parameter is declared nullable for the failure case.
/// </summary>
public interface IUriService
{
/// <summary>Attempts to create an <see cref="IconUri"/> from a URI given as text.</summary>
/// <param name="stringUri">The URI text to convert.</param>
/// <param name="iconUri">Receives the result; expected to be <c>null</c> when the method returns <c>false</c>.</param>
bool TryGetUri(string stringUri, out IconUri? iconUri);
/// <summary>Attempts to create an <see cref="IconUri"/> from an existing <see cref="Uri"/>.</summary>
/// <param name="uri">The URI to convert.</param>
/// <param name="iconUri">Receives the result; expected to be <c>null</c> when the method returns <c>false</c>.</param>
bool TryGetUri(Uri uri, out IconUri? iconUri);
/// <summary>Attempts to create an <see cref="IconUri"/> for the target a response redirects to.</summary>
/// <param name="response">The response to read the redirect target from.</param>
/// <param name="originalUri">The URI whose request produced <paramref name="response"/>.</param>
/// <param name="iconUri">Receives the result; expected to be <c>null</c> when the method returns <c>false</c>.</param>
bool TryGetRedirect(HttpResponseMessage response, IconUri originalUri, out IconUri? iconUri);
}

View File

@ -1,449 +1,47 @@
using System.Net;
using System.Text;
#nullable enable
using AngleSharp.Html.Parser;
using Bit.Icons.Extensions;
using Bit.Icons.Models;
namespace Bit.Icons.Services;
public class IconFetchingService : IIconFetchingService
{
private readonly HashSet<string> _iconRels =
new HashSet<string> { "icon", "apple-touch-icon", "shortcut icon" };
private readonly HashSet<string> _blacklistedRels =
new HashSet<string> { "preload", "image_src", "preconnect", "canonical", "alternate", "stylesheet" };
private readonly HashSet<string> _iconExtensions =
new HashSet<string> { ".ico", ".png", ".jpg", ".jpeg" };
private readonly string _pngMediaType = "image/png";
private readonly byte[] _pngHeader = new byte[] { 137, 80, 78, 71 };
private readonly byte[] _webpHeader = Encoding.UTF8.GetBytes("RIFF");
private readonly string _icoMediaType = "image/x-icon";
private readonly string _icoAltMediaType = "image/vnd.microsoft.icon";
private readonly byte[] _icoHeader = new byte[] { 00, 00, 01, 00 };
private readonly string _jpegMediaType = "image/jpeg";
private readonly byte[] _jpegHeader = new byte[] { 255, 216, 255 };
private readonly HashSet<string> _allowedMediaTypes;
private readonly HttpClient _httpClient;
private readonly IHttpClientFactory _httpClientFactory;
private readonly ILogger<IIconFetchingService> _logger;
private readonly IHtmlParser _parser;
private readonly IUriService _uriService;
public IconFetchingService(ILogger<IIconFetchingService> logger)
public IconFetchingService(ILogger<IIconFetchingService> logger, IHttpClientFactory httpClientFactory, IHtmlParser parser, IUriService uriService)
{
_logger = logger;
_allowedMediaTypes = new HashSet<string>
_httpClientFactory = httpClientFactory;
_parser = parser;
_uriService = uriService;
}
public async Task<Icon?> GetIconAsync(string domain)
{
var domainIcons = await DomainIcons.FetchAsync(domain, _logger, _httpClientFactory, _parser, _uriService);
var result = domainIcons.Where(result => result != null).FirstOrDefault();
return result ?? await GetFaviconAsync(domain);
}
private async Task<Icon?> GetFaviconAsync(string domain)
{
// Fall back to favicon
var faviconUriBuilder = new UriBuilder
{
_pngMediaType,
_icoMediaType,
_icoAltMediaType,
_jpegMediaType
Scheme = "https",
Host = domain,
Path = "/favicon.ico"
};
_httpClient = new HttpClient(new HttpClientHandler
if (faviconUriBuilder.TryBuild(out var faviconUri))
{
AllowAutoRedirect = false,
AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate,
});
_httpClient.Timeout = TimeSpan.FromSeconds(20);
_httpClient.MaxResponseContentBufferSize = 5000000; // 5 MB
}
public async Task<IconResult> GetIconAsync(string domain)
{
if (IPAddress.TryParse(domain, out _))
{
_logger.LogWarning("IP address: {0}.", domain);
return null;
return await new IconLink(faviconUri!).FetchAsync(_logger, _httpClientFactory, _uriService);
}
if (!Uri.TryCreate($"https://{domain}", UriKind.Absolute, out var parsedHttpsUri))
{
_logger.LogWarning("Bad domain: {0}.", domain);
return null;
}
var uri = parsedHttpsUri;
var response = await GetAndFollowAsync(uri, 2);
if ((response == null || !response.IsSuccessStatusCode) &&
Uri.TryCreate($"http://{parsedHttpsUri.Host}", UriKind.Absolute, out var parsedHttpUri))
{
Cleanup(response);
uri = parsedHttpUri;
response = await GetAndFollowAsync(uri, 2);
if (response == null || !response.IsSuccessStatusCode)
{
var dotCount = domain.Count(c => c == '.');
if (dotCount > 1 && DomainName.TryParseBaseDomain(domain, out var baseDomain) &&
Uri.TryCreate($"https://{baseDomain}", UriKind.Absolute, out var parsedBaseUri))
{
Cleanup(response);
uri = parsedBaseUri;
response = await GetAndFollowAsync(uri, 2);
}
else if (dotCount < 2 &&
Uri.TryCreate($"https://www.{parsedHttpsUri.Host}", UriKind.Absolute, out var parsedWwwUri))
{
Cleanup(response);
uri = parsedWwwUri;
response = await GetAndFollowAsync(uri, 2);
}
}
}
if (response?.Content == null || !response.IsSuccessStatusCode)
{
_logger.LogWarning("Couldn't load a website for {0}: {1}.", domain,
response?.StatusCode.ToString() ?? "null");
Cleanup(response);
return null;
}
var parser = new HtmlParser();
using (response)
using (var htmlStream = await response.Content.ReadAsStreamAsync())
using (var document = await parser.ParseDocumentAsync(htmlStream))
{
uri = response.RequestMessage.RequestUri;
if (document.DocumentElement == null)
{
_logger.LogWarning("No DocumentElement for {0}.", domain);
return null;
}
var baseUrl = "/";
var baseUrlNode = document.QuerySelector("head base[href]");
if (baseUrlNode != null)
{
var hrefAttr = baseUrlNode.Attributes["href"];
if (!string.IsNullOrWhiteSpace(hrefAttr?.Value))
{
baseUrl = hrefAttr.Value;
}
baseUrlNode = null;
hrefAttr = null;
}
var icons = new List<IconResult>();
var links = document.QuerySelectorAll("head link[href]");
if (links != null)
{
foreach (var link in links.Take(200))
{
var hrefAttr = link.Attributes["href"];
if (string.IsNullOrWhiteSpace(hrefAttr?.Value))
{
continue;
}
var relAttr = link.Attributes["rel"];
var sizesAttr = link.Attributes["sizes"];
if (relAttr != null && _iconRels.Contains(relAttr.Value.ToLower()))
{
icons.Add(new IconResult(hrefAttr.Value, sizesAttr?.Value));
}
else if (relAttr == null || !_blacklistedRels.Contains(relAttr.Value.ToLower()))
{
try
{
var extension = Path.GetExtension(hrefAttr.Value);
if (_iconExtensions.Contains(extension.ToLower()))
{
icons.Add(new IconResult(hrefAttr.Value, sizesAttr?.Value));
}
}
catch (ArgumentException) { }
}
sizesAttr = null;
relAttr = null;
hrefAttr = null;
}
links = null;
}
var iconResultTasks = new List<Task>();
foreach (var icon in icons.OrderBy(i => i.Priority).Take(10))
{
Uri iconUri = null;
if (icon.Path.StartsWith("//") && Uri.TryCreate($"{GetScheme(uri)}://{icon.Path.Substring(2)}",
UriKind.Absolute, out var slashUri))
{
iconUri = slashUri;
}
else if (Uri.TryCreate(icon.Path, UriKind.Relative, out var relUri))
{
iconUri = ResolveUri($"{GetScheme(uri)}://{uri.Host}", baseUrl, relUri.OriginalString);
}
else if (Uri.TryCreate(icon.Path, UriKind.Absolute, out var absUri))
{
iconUri = absUri;
}
if (iconUri != null)
{
var task = GetIconAsync(iconUri).ContinueWith(async (r) =>
{
var result = await r;
if (result != null)
{
icon.Path = iconUri.ToString();
icon.Icon = result.Icon;
}
});
iconResultTasks.Add(task);
}
}
await Task.WhenAll(iconResultTasks);
if (!icons.Any(i => i.Icon != null))
{
var faviconUri = ResolveUri($"{GetScheme(uri)}://{uri.Host}", "favicon.ico");
var result = await GetIconAsync(faviconUri);
if (result != null)
{
icons.Add(result);
}
else
{
_logger.LogWarning("No favicon.ico found for {0}.", uri.Host);
return null;
}
}
return icons.Where(i => i.Icon != null).OrderBy(i => i.Priority).First();
}
}
private async Task<IconResult> GetIconAsync(Uri uri)
{
using (var response = await GetAndFollowAsync(uri, 2))
{
if (response?.Content?.Headers == null || !response.IsSuccessStatusCode)
{
response?.Content?.Dispose();
return null;
}
var format = response.Content.Headers?.ContentType?.MediaType;
var bytes = await response.Content.ReadAsByteArrayAsync();
response.Content.Dispose();
if (format == null || !_allowedMediaTypes.Contains(format))
{
if (HeaderMatch(bytes, _icoHeader))
{
format = _icoMediaType;
}
else if (HeaderMatch(bytes, _pngHeader) || HeaderMatch(bytes, _webpHeader))
{
format = _pngMediaType;
}
else if (HeaderMatch(bytes, _jpegHeader))
{
format = _jpegMediaType;
}
else
{
return null;
}
}
return new IconResult(uri, bytes, format);
}
}
private async Task<HttpResponseMessage> GetAndFollowAsync(Uri uri, int maxRedirectCount)
{
var response = await GetAsync(uri);
if (response == null)
{
return null;
}
return await FollowRedirectsAsync(response, maxRedirectCount);
}
private async Task<HttpResponseMessage> GetAsync(Uri uri)
{
if (uri == null)
{
return null;
}
// Prevent non-http(s) and non-default ports
if ((uri.Scheme != "http" && uri.Scheme != "https") || !uri.IsDefaultPort)
{
return null;
}
// Prevent local hosts (localhost, bobs-pc, etc) and IP addresses
if (!uri.Host.Contains(".") || IPAddress.TryParse(uri.Host, out _))
{
return null;
}
// Resolve host to make sure it is not an internal/private IP address
try
{
var hostEntry = Dns.GetHostEntry(uri.Host);
if (hostEntry?.AddressList.Any(ip => IsInternal(ip)) ?? true)
{
return null;
}
}
catch
{
return null;
}
using (var message = new HttpRequestMessage())
{
message.RequestUri = uri;
message.Method = HttpMethod.Get;
// Let's add some headers to look like we're coming from a web browser request. Some websites
// will block our request without these.
message.Headers.Add("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " +
"(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36 Edge/16.16299");
message.Headers.Add("Accept-Language", "en-US,en;q=0.8");
message.Headers.Add("Cache-Control", "no-cache");
message.Headers.Add("Pragma", "no-cache");
message.Headers.Add("Accept", "text/html,application/xhtml+xml,application/xml;" +
"q=0.9,image/webp,image/apng,*/*;q=0.8");
try
{
return await _httpClient.SendAsync(message);
}
catch
{
return null;
}
}
}
private async Task<HttpResponseMessage> FollowRedirectsAsync(HttpResponseMessage response,
int maxFollowCount, int followCount = 0)
{
if (response == null || response.IsSuccessStatusCode || followCount > maxFollowCount)
{
return response;
}
if (!(response.StatusCode == HttpStatusCode.Redirect ||
response.StatusCode == HttpStatusCode.MovedPermanently ||
response.StatusCode == HttpStatusCode.RedirectKeepVerb ||
response.StatusCode == HttpStatusCode.SeeOther) ||
response.Headers.Location == null)
{
Cleanup(response);
return null;
}
Uri location = null;
if (response.Headers.Location.IsAbsoluteUri)
{
if (response.Headers.Location.Scheme != "http" && response.Headers.Location.Scheme != "https")
{
if (Uri.TryCreate($"https://{response.Headers.Location.OriginalString}",
UriKind.Absolute, out var newUri))
{
location = newUri;
}
}
else
{
location = response.Headers.Location;
}
}
else
{
var requestUri = response.RequestMessage.RequestUri;
location = ResolveUri($"{GetScheme(requestUri)}://{requestUri.Host}",
response.Headers.Location.OriginalString);
}
Cleanup(response);
var newResponse = await GetAsync(location);
if (newResponse != null)
{
followCount++;
var redirectedResponse = await FollowRedirectsAsync(newResponse, maxFollowCount, followCount);
if (redirectedResponse != null)
{
if (redirectedResponse != newResponse)
{
Cleanup(newResponse);
}
return redirectedResponse;
}
}
return null;
}
private bool HeaderMatch(byte[] imageBytes, byte[] header)
{
return imageBytes.Length >= header.Length && header.SequenceEqual(imageBytes.Take(header.Length));
}
private Uri ResolveUri(string baseUrl, params string[] paths)
{
var url = baseUrl;
foreach (var path in paths)
{
if (Uri.TryCreate(new Uri(url), path, out var r))
{
url = r.ToString();
}
}
return new Uri(url);
}
private void Cleanup(IDisposable obj)
{
obj?.Dispose();
obj = null;
}
private string GetScheme(Uri uri)
{
return uri != null && uri.Scheme == "http" ? "http" : "https";
}
public static bool IsInternal(IPAddress ip)
{
if (IPAddress.IsLoopback(ip))
{
return true;
}
var ipString = ip.ToString();
if (ipString == "::1" || ipString == "::" || ipString.StartsWith("::ffff:"))
{
return true;
}
// IPv6
if (ip.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6)
{
return ipString.StartsWith("fc") || ipString.StartsWith("fd") ||
ipString.StartsWith("fe") || ipString.StartsWith("ff");
}
// IPv4
var bytes = ip.GetAddressBytes();
return (bytes[0]) switch
{
0 => true,
10 => true,
127 => true,
169 => bytes[1] == 254, // Cloud environments, such as AWS
172 => bytes[1] < 32 && bytes[1] >= 16,
192 => bytes[1] == 168,
_ => false,
};
}
}

View File

@ -0,0 +1,109 @@
#nullable enable
using System.Net;
using System.Net.Sockets;
using Bit.Icons.Extensions;
using Bit.Icons.Models;
namespace Bit.Icons.Services;
/// <summary>
/// Default <see cref="IUriService"/>: resolves a URI's host to an IPv4 address and wraps
/// both in an <see cref="IconUri"/>. The Get* methods throw on failure; the Try* methods
/// catch every exception and report failure through their return value.
/// </summary>
public class UriService : IUriService
{
    /// <summary>Creates an <see cref="IconUri"/> from URI text.</summary>
    /// <exception cref="UriFormatException">The text is not a valid URI.</exception>
    /// <exception cref="Exception">No IPv4 address could be determined for the host.</exception>
    public IconUri GetUri(string inputUri)
    {
        var uri = new Uri(inputUri);
        return new IconUri(uri, DetermineIp(uri));
    }

    public bool TryGetUri(string stringUri, out IconUri? iconUri)
    {
        if (!Uri.TryCreate(stringUri, UriKind.Absolute, out var uri))
        {
            iconUri = null;
            return false;
        }

        return TryGetUri(uri, out iconUri);
    }

    /// <summary>Creates an <see cref="IconUri"/> from <paramref name="uri"/>.</summary>
    /// <exception cref="Exception">No IPv4 address could be determined for the host.</exception>
    public IconUri GetUri(Uri uri)
    {
        return new IconUri(uri, DetermineIp(uri));
    }

    public bool TryGetUri(Uri uri, out IconUri? iconUri)
    {
        try
        {
            iconUri = GetUri(uri);
            return true;
        }
        catch (Exception)
        {
            // Resolution failures (DNS errors, no IPv4 address) are reported as "no URI".
            iconUri = null;
            return false;
        }
    }

    /// <summary>Creates an <see cref="IconUri"/> for the Location a response redirects to.</summary>
    /// <exception cref="Exception">The response has no Location header, or the target host has no IPv4 address.</exception>
    public IconUri GetRedirect(HttpResponseMessage response, IconUri originalUri)
    {
        if (response.Headers.Location == null)
        {
            throw new Exception("No redirect location found.");
        }

        var redirectUri = DetermineRedirectUri(response.Headers.Location, originalUri);
        return new IconUri(redirectUri, DetermineIp(redirectUri));
    }

    public bool TryGetRedirect(HttpResponseMessage response, IconUri originalUri, out IconUri? iconUri)
    {
        try
        {
            iconUri = GetRedirect(response, originalUri);
            return true;
        }
        catch (Exception)
        {
            iconUri = null;
            return false;
        }
    }

    /// <summary>
    /// Turns a Location header value into an absolute URI. Absolute non-http(s) targets are
    /// rewritten to https; relative targets are resolved against the original request URI.
    /// </summary>
    private static Uri DetermineRedirectUri(Uri responseUri, IconUri originalIconUri)
    {
        if (responseUri.IsAbsoluteUri)
        {
            return responseUri.IsHypertext() ? responseUri : responseUri.ChangeScheme("https");
        }

        // InnerUri carries the request's scheme/path/query but points at the resolved IP, so
        // restore the original host name before resolving. Resolving with Uri(base, relative)
        // instead of assigning UriBuilder.Path keeps "?query" and "#fragment" intact (the Path
        // setter would percent-encode them) and honours path-relative locations.
        var requestUri = new UriBuilder(originalIconUri.InnerUri) { Host = originalIconUri.Host }.Uri;
        return new Uri(requestUri, responseUri);
    }

    /// <summary>
    /// Returns the host itself when it is already an IP literal; otherwise the first IPv4
    /// (or IPv4-mapped IPv6) address the host resolves to, as an IPv4 address.
    /// </summary>
    private static IPAddress DetermineIp(Uri uri)
    {
        if (IPAddress.TryParse(uri.Host, out var ip))
        {
            return ip;
        }

        var hostEntry = Dns.GetHostEntry(uri.Host);
        ip = hostEntry.AddressList
            .FirstOrDefault(candidate => candidate.AddressFamily == AddressFamily.InterNetwork || candidate.IsIPv4MappedToIPv6)
            ?.MapToIPv4();
        if (ip == null)
        {
            throw new Exception($"Unable to determine IP for {uri.Host}");
        }

        return ip;
    }
}

View File

@ -1,7 +1,7 @@
using System.Globalization;
using Bit.Core.Settings;
using Bit.Core.Utilities;
using Bit.Icons.Services;
using Bit.Icons.Extensions;
using Bit.SharedWeb.Utilities;
using Microsoft.Net.Http.Headers;
@ -30,6 +30,12 @@ public class Startup
ConfigurationBinder.Bind(Configuration.GetSection("IconsSettings"), iconsSettings);
services.AddSingleton(s => iconsSettings);
// Http client
services.ConfigureHttpClients();
// Add HtmlParser
services.AddHtmlParsing();
// Cache
services.AddMemoryCache(options =>
{
@ -37,8 +43,7 @@ public class Startup
});
// Services
services.AddSingleton<IDomainMappingService, DomainMappingService>();
services.AddSingleton<IIconFetchingService, IconFetchingService>();
services.AddServices();
// Mvc
services.AddMvc();

View File

@ -0,0 +1,42 @@
#nullable enable
using System.Net;
namespace Bit.Icons.Extensions;
public static class IPAddressExtension
{
    /// <summary>
    /// Reports whether <paramref name="ip"/> falls into one of the ranges this service refuses
    /// to contact: loopback, the unspecified address, any IPv4-mapped IPv6 address, IPv6
    /// addresses whose text starts with fc/fd/fe/ff, and the IPv4 blocks 0/8, 10/8, 127/8,
    /// 169.254/16, 172.16/12 and 192.168/16.
    /// </summary>
    public static bool IsInternal(this IPAddress ip)
    {
        if (IPAddress.IsLoopback(ip))
        {
            return true;
        }

        var text = ip.ToString();
        var isSpecialV6 = text == "::1" || text == "::" || text.StartsWith("::ffff:");
        if (isSpecialV6)
        {
            return true;
        }

        // IPv6: decided purely on the textual prefix
        if (ip.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6)
        {
            foreach (var prefix in new[] { "fc", "fd", "fe", "ff" })
            {
                if (text.StartsWith(prefix))
                {
                    return true;
                }
            }

            return false;
        }

        // IPv4: decided on the first two octets
        var octets = ip.GetAddressBytes();
        var first = octets[0];
        if (first == 0 || first == 10 || first == 127)
        {
            return true;
        }

        if (first == 169)
        {
            // Cloud environments, such as AWS
            return octets[1] == 254;
        }

        if (first == 172)
        {
            return octets[1] >= 16 && octets[1] < 32;
        }

        if (first == 192)
        {
            return octets[1] == 168;
        }

        return false;
    }
}

View File

@ -0,0 +1,44 @@
# nullable enable
using System.Net;
using AngleSharp.Html.Parser;
using Bit.Icons.Services;
namespace Bit.Icons.Extensions;
/// <summary>
/// Dependency-injection registrations for the icons service, shared by the application
/// startup and by tests.
/// </summary>
public static class ServiceCollectionExtension
{
    private const string _iconsClientName = "Icons";

    // Let's add some headers to look like we're coming from a web browser request. Some websites
    // will block our request without these.
    private const string _browserUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " +
        "(KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36";
    private const string _browserAccept = "text/html,application/xhtml+xml,application/xml;" +
        "q=0.9,image/webp,image/apng,*/*;q=0.8";

    /// <summary>
    /// Registers the named "Icons" HTTP client: 20 second timeout, 5 MB response buffer,
    /// browser-like default headers, and a handler with automatic redirects disabled.
    /// </summary>
    public static void ConfigureHttpClients(this IServiceCollection services)
    {
        var clientBuilder = services.AddHttpClient(_iconsClientName, client => ApplyIconsClientDefaults(client));
        clientBuilder.ConfigurePrimaryHttpMessageHandler(() => CreateIconsHandler());
    }

    /// <summary>Registers the singleton <see cref="IHtmlParser"/>.</summary>
    public static void AddHtmlParsing(this IServiceCollection services)
    {
        services.AddSingleton<IHtmlParser, HtmlParser>();
    }

    /// <summary>Registers the URI, domain-mapping and icon-fetching services as singletons.</summary>
    public static void AddServices(this IServiceCollection services)
    {
        services.AddSingleton<IUriService, UriService>();
        services.AddSingleton<IDomainMappingService, DomainMappingService>();
        services.AddSingleton<IIconFetchingService, IconFetchingService>();
    }

    private static void ApplyIconsClientDefaults(HttpClient client)
    {
        client.Timeout = TimeSpan.FromSeconds(20);
        client.MaxResponseContentBufferSize = 5_000_000; // 5 MB

        var headers = client.DefaultRequestHeaders;
        headers.Add("User-Agent", _browserUserAgent);
        headers.Add("Accept-Language", "en-US,en;q=0.8");
        headers.Add("Cache-Control", "no-cache");
        headers.Add("Pragma", "no-cache");
        headers.Add("Accept", _browserAccept);
    }

    private static HttpClientHandler CreateIconsHandler()
    {
        var handler = new HttpClientHandler();
        // Redirects are followed manually elsewhere in the service.
        handler.AllowAutoRedirect = false;
        handler.AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate;
        return handler;
    }
}

View File

@ -0,0 +1,20 @@
#nullable enable
namespace Bit.Icons.Extensions;
public static class UriBuilderExtension
{
    /// <summary>
    /// Builds the <see cref="Uri"/> described by <paramref name="builder"/> without throwing
    /// when its components do not form a valid URI.
    /// </summary>
    /// <param name="builder">The builder to read.</param>
    /// <param name="uri">
    /// The built URI when the method returns <c>true</c>; <c>null</c> otherwise. The
    /// NotNullWhen annotation lets callers use it after a successful call without a
    /// null-forgiving operator.
    /// </param>
    /// <returns><c>true</c> when the URI was built; <c>false</c> on <see cref="UriFormatException"/>.</returns>
    public static bool TryBuild(this UriBuilder builder, [System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out Uri? uri)
    {
        try
        {
            uri = builder.Uri;
            return true;
        }
        catch (UriFormatException)
        {
            uri = null;
            return false;
        }
    }
}

View File

@ -0,0 +1,41 @@

#nullable enable
namespace Bit.Icons.Extensions;
public static class UriExtension
{
    /// <summary>Returns <c>true</c> when the URI's scheme is http or https.</summary>
    public static bool IsHypertext(this Uri uri)
    {
        return uri.Scheme == "http" || uri.Scheme == "https";
    }

    /// <summary>
    /// Returns a copy of <paramref name="uri"/> that uses <paramref name="scheme"/> on that
    /// scheme's default port, keeping the host, path and query string.
    /// </summary>
    public static Uri ChangeScheme(this Uri uri, string scheme)
    {
        // Path and query must be assigned separately: UriBuilder.Path percent-encodes '?',
        // so passing PathAndQuery through it would turn "/a?b=1" into "/a%3Fb=1".
        return new UriBuilder(scheme, uri.Host)
        {
            Path = uri.AbsolutePath,
            // Uri.Query includes the leading '?'; strip it so the builder adds exactly one.
            Query = uri.Query.TrimStart('?'),
        }.Uri;
    }

    /// <summary>Returns a copy of <paramref name="uri"/> with its host replaced by <paramref name="host"/>.</summary>
    public static Uri ChangeHost(this Uri uri, string host)
    {
        return new UriBuilder(uri) { Host = host }.Uri;
    }

    /// <summary>
    /// Resolves each of <paramref name="paths"/> in turn against the result of the previous step.
    /// </summary>
    public static Uri ConcatPath(this Uri uri, params string[] paths)
        => uri.ConcatPath((IEnumerable<string>)paths);

    /// <summary>
    /// Resolves each of <paramref name="paths"/> in turn against the result of the previous
    /// step. Resolution stops at the first segment that cannot be combined, returning what
    /// has been built so far.
    /// </summary>
    public static Uri ConcatPath(this Uri uri, IEnumerable<string> paths)
    {
        var current = uri;
        foreach (var path in paths)
        {
            if (!Uri.TryCreate(current, path, out var next))
            {
                return current;
            }

            current = next;
        }

        return current;
    }
}