Note: Students should always aim to produce publication-worthy tables and figures. Unless otherwise stated, tables should be rendered using stargazer::(), while figures can be rendered using ggplot2::() or plot(). Regardless, tables and figures should always be presented with necessary formatting – e.g., (sub)title, axis (variable) labels and titles, a clearly-identifiable legend and key, etc. Problem sets must always be compiled using LaTex or RMarkdown and include the full coding routine (with notes explaining your implementation) used to complete each problem (10pts).


  1. Using the ag_news dataset from the textdata package in R (textdata::dataset_ag_news()), construct a Continuous Bag of Words (CBOW) model. With that model, use predict() to recover the 5 nearest words for a series of 5 other words that you choose.
ag <- textdata::dataset_ag_news() 
ag$title_clean <- reduce_complexity(ag$title)

cbow_model = word2vec(x = ag$title, # Ag Title Clean
                      type = "cbow", # Cont. Bag of Words 
                      dim = 15, # 15 Dimensions
                      iter = 20) # 20 Iterations (max)

cbow_lookslike <- predict(cbow_model, c("soccer", "champion", "war", "mars", "river"), type = "nearest", top_n = 5)
print(cbow_lookslike)
## $soccer
##    term1      term2 similarity rank
## 1 soccer     bronze  0.9598973    1
## 2 soccer      final  0.9452686    2
## 3 soccer qualifying  0.9443066    3
## 4 soccer     tennis  0.9433466    4
## 5 soccer  qualifier  0.9430767    5
## 
## $champion
##      term1     term2 similarity rank
## 1 champion defending  0.9639828    1
## 2 champion     champ  0.9475932    2
## 3 champion   Noguchi  0.9278772    3
## 4 champion   cycling  0.9270868    4
## 5 champion  guessing  0.9242694    5
## 
## $war
##   term1     term2 similarity rank
## 1   war elections  0.9542273    1
## 2   war    action  0.9511375    2
## 3   war       ban  0.9296306    3
## 4   war     state  0.9282848    4
## 5   war    warned  0.9273162    5
## 
## $mars
##   term1       term2 similarity rank
## 1  mars      resort  0.8856683    1
## 2  mars Bangladeshi  0.8798008    2
## 3  mars    Caucasus  0.8784330    3
## 4  mars    condemns  0.8760334    4
## 5  mars    reporter  0.8669539    5
## 
## $river
##   term1      term2 similarity rank
## 1 river     quakes  0.9323485    1
## 2 river      train  0.9317452    2
## 3 river earthquake  0.9278418    3
## 4 river landslides  0.9272598    4
## 5 river     trains  0.9270689    5
  1. Using SBERT and the ag_news dataset from the textdata package in R (textdata::dataset_ag_news()), recover both the embeddings from a sample of 1000 title. Afterwards, recover pairwise cosine similariy scores and report the top and bottom-5 most similar using stargazer (5pts).
set.seed(1234)

ag <- textdata::dataset_ag_news() %>%
  dplyr::sample_n(100) %>%
  mutate(row_id = row_number())

sentences_pairwise <- expand.grid(row_i = ag$row_id,
              row_j = ag$row_id) %>%
  filter(row_i != row_j) %>%  # remove self-pairs
  left_join(ag, by = c("row_i" = "row_id")) %>%
  rename(text_i = title) %>%
  left_join(ag, by = c("row_j" = "row_id")) %>%
  rename(text_j = title) %>%
  rename_with(~ gsub("\\.x$", "_i", .x)) %>%
  rename_with(~ gsub("\\.y$", "_j", .x)) 


library(reticulate) # Activate Reticulate
## Warning: package 'reticulate' was built under R version 4.4.3
virtualenv_create()# Create  Virtual Environment (If Needed)
## virtualenv: ~/.virtualenvs/r-reticulate
use_virtualenv(required = TRUE) # Activate Environment

required_packages <- c("torch", "transformers", "sentence-transformers")
for (pkg in required_packages) {
  if (!py_module_available(pkg)) {
    virtualenv_install(packages = pkg)
  }
} # Install torch, transformers, and sentence-transformers (if needed)
## Using virtual environment "~/.virtualenvs/r-reticulate" ...
## + "C:/Users/jaketruscott/Documents/.virtualenvs/r-reticulate/Scripts/python.exe" -m pip install --upgrade --no-user sentence-transformers
sentence_transformers <- import("sentence_transformers") # Import sentence-transformers
util <- import("sentence_transformers.util") # Util Package

model_name = "sentence-transformers/all-MiniLM-L6-v2" # SBERT
model <- sentence_transformers$SentenceTransformer(model_name) # Declare Model (SBERT)

embeddings <- model$encode(ag$title) # Convert Sentences to Embeddings
cosine_sim <- data.frame() # Output DF

for (i in 1:nrow(sentences_pairwise)){

  temp_row <- sentences_pairwise[i,] # Temp Row
  i_embeddings <- embeddings[temp_row$row_i,] # i embedding
  j_embeddings <- embeddings[temp_row$row_j,] # j embedding
  temp_similarity <- util$cos_sim(i_embeddings, j_embeddings)$item() # Cosine Sim

  cosine_sim <- bind_rows(cosine_sim,
                          data.frame(temp_row, similarity = temp_similarity)) # Export

  if (i %% 500 == 0){
    message('Completed ', i, '/', nrow(sentences_pairwise))
  }

}
## Completed 500/9900
## Completed 1000/9900
## Completed 1500/9900
## Completed 2000/9900
## Completed 2500/9900
## Completed 3000/9900
## Completed 3500/9900
## Completed 4000/9900
## Completed 4500/9900
## Completed 5000/9900
## Completed 5500/9900
## Completed 6000/9900
## Completed 6500/9900
## Completed 7000/9900
## Completed 7500/9900
## Completed 8000/9900
## Completed 8500/9900
## Completed 9000/9900
## Completed 9500/9900
top_5 <- cosine_sim %>%
  arrange(desc(similarity)) %>%
  slice_head(n = 3) # 5 Most Sim

top_5$text_i[1]
## [1] "Crude futures rally over \\$53"
top_5$text_j[1]
## [1] "Crude prices make a retreat"
top_5$similarity[1]
## [1] 0.6020842
bottom_5 <- cosine_sim %>%
  arrange(similarity) %>%
  slice_head(n = 3) %>%
  arrange(desc(similarity)) # 5 Least Sim

bottom_5$text_i[1]
## [1] "State of Emergency Declared in Iraq"
bottom_5$text_j[1]
## [1] "Slower Pub Sales Hit Whitbread"
bottom_5$similarity[1]
## [1] -0.1957191