using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
using Bit.Icons.Models;
using AngleSharp.Parser.Html;

namespace Bit.Icons.Services
{
    /// <summary>
    /// Fetches a website's landing page, collects icon candidates from its
    /// <c>&lt;head&gt;</c> links, downloads them, and returns the best match.
    /// Falls back to <c>/favicon.ico</c> when no linked icon could be downloaded.
    /// </summary>
    public class IconFetchingService : IIconFetchingService
    {
        // HTTP 308 (Permanent Redirect). Compared numerically so the code does not depend on
        // the HttpStatusCode enum member being available in the targeted framework.
        private const int PermanentRedirectStatusCode = 308;

        // Ordinal, case-insensitive lookups: attribute values and file extensions are protocol
        // tokens, so matching must not depend on the current culture.
        private readonly HashSet<string> _iconRels =
            new HashSet<string>(StringComparer.OrdinalIgnoreCase) { "icon", "apple-touch-icon", "shortcut icon" };
        private readonly HashSet<string> _iconExtensions =
            new HashSet<string>(StringComparer.OrdinalIgnoreCase) { ".ico", ".png", ".jpg", ".jpeg" };

        // Media types and the leading bytes used to sniff the format when the server's
        // Content-Type is missing or not one of the allowed types.
        private readonly string _pngMediaType = "image/png";
        private readonly byte[] _pngHeader = new byte[] { 137, 80, 78, 71 };
        private readonly string _icoMediaType = "image/x-icon";
        private readonly string _icoAltMediaType = "image/vnd.microsoft.icon";
        private readonly byte[] _icoHeader = new byte[] { 0, 0, 1, 0 };
        private readonly string _jpegMediaType = "image/jpeg";
        private readonly byte[] _jpegHeader = new byte[] { 255, 216, 255 };

        private readonly HashSet<string> _allowedMediaTypes;
        private readonly HttpClient _httpClient;

        /// <summary>
        /// Creates the service with an <see cref="HttpClient"/> that does not auto-follow
        /// redirects (they are followed manually by <see cref="FollowRedirectsAsync"/>),
        /// decompresses gzip/deflate responses, and times out after 20 seconds.
        /// </summary>
        public IconFetchingService()
        {
            _allowedMediaTypes = new HashSet<string>
            {
                _pngMediaType,
                _icoMediaType,
                _icoAltMediaType,
                _jpegMediaType
            };

            _httpClient = new HttpClient(new HttpClientHandler
            {
                AllowAutoRedirect = false,
                AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate,
            });
            _httpClient.Timeout = TimeSpan.FromSeconds(20);
        }

        /// <summary>
        /// Finds and downloads the best icon for <paramref name="domain"/>.
        /// </summary>
        /// <param name="domain">Host name to look up; tried as https, then http, then https with a "www." prefix.</param>
        /// <returns>
        /// The downloaded icon with the lowest <c>Priority</c> value, or <c>null</c> when the domain is not a
        /// valid host, no page could be fetched, the page could not be parsed, or no icon could be downloaded.
        /// </returns>
        public async Task<IconResult> GetIconAsync(string domain)
        {
            if(!Uri.TryCreate($"https://{domain}", UriKind.Absolute, out var uri))
            {
                return null;
            }

            var response = await GetAndFollowAsync(uri, 2);
            foreach(var fallback in new[] { $"http://{domain}", $"https://www.{domain}" })
            {
                if(response != null && response.IsSuccessStatusCode)
                {
                    break;
                }

                Cleanup(response);
                response = null;

                // TryCreate instead of new Uri(...): a prefix such as "www." can make an otherwise
                // valid host (e.g. an IP literal) unparsable, which should not throw.
                if(Uri.TryCreate(fallback, UriKind.Absolute, out var fallbackUri))
                {
                    response = await GetAndFollowAsync(fallbackUri, 2);
                }
            }

            if(response?.Content == null || !response.IsSuccessStatusCode)
            {
                Cleanup(response);
                return null;
            }

            var parser = new HtmlParser();
            using(response)
            using(var htmlStream = await response.Content.ReadAsStreamAsync())
            using(var document = await parser.ParseAsync(htmlStream))
            {
                // Use the URI that actually answered (after redirects) as the base for relative links.
                uri = response.RequestMessage.RequestUri;
                if(document.DocumentElement == null)
                {
                    return null;
                }

                var baseUrl = "/";
                var baseHref = document.QuerySelector("head base[href]")?.Attributes["href"]?.Value;
                if(!string.IsNullOrWhiteSpace(baseHref))
                {
                    baseUrl = baseHref;
                }

                var icons = new List<IconResult>();
                var links = document.QuerySelectorAll("head link[href]");
                if(links != null)
                {
                    // Only the first 200 links are inspected to bound the work per page.
                    foreach(var link in links.Take(200))
                    {
                        var href = link.Attributes["href"]?.Value;
                        if(string.IsNullOrWhiteSpace(href))
                        {
                            continue;
                        }

                        var rel = link.Attributes["rel"]?.Value;
                        var sizes = link.Attributes["sizes"]?.Value;
                        if((rel != null && _iconRels.Contains(rel)) || HasIconExtension(href))
                        {
                            icons.Add(new IconResult(href, sizes));
                        }
                    }
                }

                // Download at most the ten highest-priority candidates concurrently.
                var iconTasks = new List<Task>();
                foreach(var icon in icons.OrderBy(i => i.Priority).Take(10))
                {
                    var iconUri = ResolveIconUri(uri, baseUrl, icon.Path);
                    if(iconUri != null)
                    {
                        iconTasks.Add(TryLoadIconAsync(icon, iconUri));
                    }
                }

                await Task.WhenAll(iconTasks);

                if(!icons.Any(i => i.Icon != null))
                {
                    var faviconUri = ResolveUri(GetOrigin(uri), "favicon.ico");
                    var favicon = await GetIconAsync(faviconUri);
                    if(favicon == null)
                    {
                        return null;
                    }

                    icons.Add(favicon);
                }

                return icons.Where(i => i.Icon != null).OrderBy(i => i.Priority).First();
            }
        }

        /// <summary>
        /// Returns true when <paramref name="href"/> ends in one of the known icon file extensions.
        /// </summary>
        private bool HasIconExtension(string href)
        {
            try
            {
                return _iconExtensions.Contains(Path.GetExtension(href));
            }
            catch(ArgumentException)
            {
                // Path.GetExtension can reject hrefs containing characters that are invalid in paths.
                return false;
            }
        }

        /// <summary>
        /// Turns an icon href from the page into an absolute URI, or returns null when it cannot be parsed.
        /// </summary>
        private Uri ResolveIconUri(Uri pageUri, string baseUrl, string path)
        {
            // Scheme-relative ("//host/path"): reuse the page's scheme.
            if(path.StartsWith("//", StringComparison.Ordinal) &&
                Uri.TryCreate($"{GetScheme(pageUri)}://{path.Substring(2)}", UriKind.Absolute, out var schemeRelativeUri))
            {
                return schemeRelativeUri;
            }

            // Relative is tried before absolute so that a leading-slash path is never treated
            // as an absolute (non-HTTP) URI.
            if(Uri.TryCreate(path, UriKind.Relative, out var relativeUri))
            {
                return ResolveUri(GetOrigin(pageUri), baseUrl, relativeUri.OriginalString);
            }

            if(Uri.TryCreate(path, UriKind.Absolute, out var absoluteUri))
            {
                return absoluteUri;
            }

            return null;
        }

        /// <summary>
        /// Downloads <paramref name="iconUri"/> and, on success, stores the image and the resolved
        /// path on <paramref name="icon"/>. Failures leave <paramref name="icon"/> untouched.
        /// </summary>
        private async Task TryLoadIconAsync(IconResult icon, Uri iconUri)
        {
            try
            {
                var result = await GetIconAsync(iconUri);
                if(result != null)
                {
                    icon.Path = iconUri.ToString();
                    icon.Icon = result.Icon;
                }
            }
            catch(Exception)
            {
                // Best effort: a failing candidate must not abort the other downloads or the
                // favicon.ico fallback, so the error is deliberately dropped here.
            }
        }

        /// <summary>
        /// Downloads a single image. Returns null when the request fails or the payload is neither
        /// served with an allowed media type nor recognizable by its leading bytes.
        /// </summary>
        private async Task<IconResult> GetIconAsync(Uri uri)
        {
            // Disposing the response also disposes its content.
            using(var response = await GetAndFollowAsync(uri, 2))
            {
                if(response?.Content?.Headers == null || !response.IsSuccessStatusCode)
                {
                    return null;
                }

                var format = response.Content.Headers.ContentType?.MediaType;
                var bytes = await response.Content.ReadAsByteArrayAsync();

                if(format == null || !_allowedMediaTypes.Contains(format))
                {
                    // Content-Type is missing or unexpected: sniff the format from the file header.
                    if(HeaderMatch(bytes, _icoHeader))
                    {
                        format = _icoMediaType;
                    }
                    else if(HeaderMatch(bytes, _pngHeader))
                    {
                        format = _pngMediaType;
                    }
                    else if(HeaderMatch(bytes, _jpegHeader))
                    {
                        format = _jpegMediaType;
                    }
                    else
                    {
                        return null;
                    }
                }

                return new IconResult(uri, bytes, format);
            }
        }

        /// <summary>
        /// Issues a GET for <paramref name="uri"/> and follows redirects manually.
        /// Returns null when the request fails.
        /// </summary>
        private async Task<HttpResponseMessage> GetAndFollowAsync(Uri uri, int maxRedirectCount)
        {
            var response = await GetAsync(uri);
            if(response == null)
            {
                return null;
            }

            return await FollowRedirectsAsync(response, maxRedirectCount);
        }

        /// <summary>
        /// Sends a browser-like GET request. Any failure while sending yields null instead of throwing.
        /// </summary>
        private async Task<HttpResponseMessage> GetAsync(Uri uri)
        {
            using(var message = new HttpRequestMessage())
            {
                message.RequestUri = uri;
                message.Method = HttpMethod.Get;

                // Let's add some headers to look like we're coming from a web browser request. Some websites
                // will block our request without these.
                message.Headers.Add("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " +
                    "(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36 Edge/16.16299");
                message.Headers.Add("Accept-Language", "en-US,en;q=0.8");
                message.Headers.Add("Cache-Control", "no-cache");
                message.Headers.Add("Pragma", "no-cache");
                message.Headers.Add("Accept", "text/html,application/xhtml+xml,application/xml;" +
                    "q=0.9,image/webp,image/apng,*/*;q=0.8");

                try
                {
                    return await _httpClient.SendAsync(message);
                }
                catch
                {
                    // Deliberate best effort: an unreachable or misbehaving host counts as "no response".
                    return null;
                }
            }
        }

        /// <summary>
        /// Follows redirect responses until a successful response arrives, the redirect budget is
        /// exhausted, or the chain breaks.
        /// </summary>
        /// <param name="response">The response to start from; ownership passes to this method.</param>
        /// <param name="maxFollowCount">Redirects are followed while the follow count does not exceed this value.</param>
        /// <param name="followCount">Number of redirects already followed.</param>
        /// <returns>
        /// The final response (successful, or the last one received once the budget is exhausted),
        /// or null when a response is not a usable redirect or a follow-up request fails.
        /// Intermediate responses are disposed.
        /// </returns>
        private async Task<HttpResponseMessage> FollowRedirectsAsync(HttpResponseMessage response, int maxFollowCount,
            int followCount = 0)
        {
            while(response != null && !response.IsSuccessStatusCode && followCount <= maxFollowCount)
            {
                var location = GetRedirectLocation(response);
                Cleanup(response);
                if(location == null)
                {
                    return null;
                }

                response = await GetAsync(location);
                followCount++;
            }

            return response;
        }

        /// <summary>
        /// Extracts the absolute redirect target from <paramref name="response"/>, or returns null when
        /// the response is not a redirect, has no Location header, or the target cannot be parsed.
        /// </summary>
        private Uri GetRedirectLocation(HttpResponseMessage response)
        {
            var location = response.Headers.Location;
            if(location == null || !IsRedirect(response.StatusCode))
            {
                return null;
            }

            // NOTE(review): on Unix-like platforms .NET may parse a leading-slash value as an absolute
            // (file) URI rather than a relative one, so leading-slash locations are always resolved
            // against the request origin. Confirm against the targeted runtime.
            if(!location.IsAbsoluteUri || location.OriginalString.StartsWith("/", StringComparison.Ordinal))
            {
                return ResolveUri(GetOrigin(response.RequestMessage.RequestUri), location.OriginalString);
            }

            if(location.Scheme == "http" || location.Scheme == "https")
            {
                return location;
            }

            // Absolute but not http(s): presumably a scheme-less value such as "host:port/path" that was
            // parsed with a bogus scheme, so retry it as an https URL.
            return Uri.TryCreate($"https://{location.OriginalString}", UriKind.Absolute, out var httpsUri)
                ? httpsUri
                : null;
        }

        /// <summary>
        /// Returns true for the redirect status codes this service follows: 301, 302, 303, 307 and 308.
        /// </summary>
        private static bool IsRedirect(HttpStatusCode statusCode)
        {
            switch(statusCode)
            {
                case HttpStatusCode.MovedPermanently:
                case HttpStatusCode.Redirect:
                case HttpStatusCode.SeeOther:
                case HttpStatusCode.RedirectKeepVerb:
                    return true;
                default:
                    return (int)statusCode == PermanentRedirectStatusCode;
            }
        }

        /// <summary>
        /// Returns true when <paramref name="imageBytes"/> starts with the bytes in <paramref name="header"/>.
        /// </summary>
        private bool HeaderMatch(byte[] imageBytes, byte[] header)
        {
            return imageBytes.Length >= header.Length && header.SequenceEqual(imageBytes.Take(header.Length));
        }

        /// <summary>
        /// Resolves each entry of <paramref name="paths"/> in turn against the result so far, starting
        /// from <paramref name="baseUrl"/>. Entries that cannot be combined are skipped.
        /// </summary>
        private Uri ResolveUri(string baseUrl, params string[] paths)
        {
            var url = baseUrl;
            foreach(var path in paths)
            {
                if(Uri.TryCreate(new Uri(url), path, out var r))
                {
                    url = r.ToString();
                }
            }

            return new Uri(url);
        }

        /// <summary>
        /// Disposes <paramref name="obj"/> when it is not null.
        /// </summary>
        private void Cleanup(IDisposable obj)
        {
            obj?.Dispose();
        }

        /// <summary>
        /// Returns "http" for http URIs and "https" for everything else, including null.
        /// </summary>
        private string GetScheme(Uri uri)
        {
            return uri != null && uri.Scheme == "http" ? "http" : "https";
        }

        /// <summary>
        /// Builds "scheme://host[:port]" for <paramref name="uri"/>. Using the authority rather than
        /// the bare host keeps a non-default port when resolving relative URLs.
        /// </summary>
        private string GetOrigin(Uri uri)
        {
            return $"{GetScheme(uri)}://{uri.Authority}";
        }
    }
}